From 50af0251e4025c2adc4a172bc4e0d7258fde7a54 Mon Sep 17 00:00:00 2001 From: Ruonan Wang <105281011+rnwang04@users.noreply.github.com> Date: Wed, 21 Jun 2023 13:23:06 +0800 Subject: [PATCH] LLM: First commit of StarCoder pybinding (#8354) * first commit of starcoder * update setup.py and fix style * add starcoder_cpp, fix style * fix style * support windows binary * update pybinding * fix style, add avx2 binary * small fix * fix style --- python/llm/setup.py | 15 +- .../llm/ggml/model/starcoder/__init__.py | 22 + .../llm/ggml/model/starcoder/starcoder.py | 433 ++++++++++++++++++ .../llm/ggml/model/starcoder/starcoder_cpp.py | 239 ++++++++++ 4 files changed, 707 insertions(+), 2 deletions(-) create mode 100644 python/llm/src/bigdl/llm/ggml/model/starcoder/__init__.py create mode 100644 python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py create mode 100644 python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder_cpp.py diff --git a/python/llm/setup.py b/python/llm/setup.py index e0aa06e6..07508840 100644 --- a/python/llm/setup.py +++ b/python/llm/setup.py @@ -66,12 +66,15 @@ def obtain_lib_urls(): base_url = "https://sourceforge.net/projects/analytics-zoo/files/bigdl-llm/" windows_binarys = ["llama.dll", "gptneox.dll", "bloom.dll", "quantize-llama.exe", "quantize-gptneox.exe", "quantize-bloom.exe", - "main-llama.exe", "main-gptneox.exe", "main-bloom.exe"] + "main-llama.exe", "main-gptneox.exe", "main-bloom.exe", + "starcoder.dll", "quantize-starcoder.exe", "main-starcoder.exe"] linux_binarys = ["libllama_avx2.so", "libgptneox_avx2.so", "libbloom_avx2.so", "libllama_avx512.so", "libgptneox_avx512.so", "libbloom_avx512.so", "quantize-llama", "quantize-gptneox", "quantize-bloom", "main-llama_avx2", "main-gptneox_avx2", "main-bloom_avx2", - "main-llama_avx512", "main-gptneox_avx512", "main-bloom_avx512"] + "main-llama_avx512", "main-gptneox_avx512", "main-bloom_avx512", + "libstarcoder_avx512.so", "main-starcoder_avx512", "quantize-starcoder", + "libstarcoder_avx2.so", "main-starcoder_avx2"] def get_date_urls(base_url): # obtain all urls based on date(format: xxxx-xx-xx) @@ -142,6 +145,9 @@ def setup_package(): "libs/main-bloom.exe", "libs/main-gptneox.exe", "libs/main-llama.exe", + "libs/main-starcoder.exe", + "libs/starcoder.dll", + "libs/quantize-starcoder.exe", ] package_data["Linux"] = [ "libs/libllama_avx2.so", @@ -153,12 +159,17 @@ def setup_package(): "libs/libbloom_avx2.so", "libs/libbloom_avx512.so", "libs/quantize-bloom", + "libs/libstarcoder_avx512.so", + "libs/libstarcoder_avx2.so", + "libs/quantize-starcoder", "libs/main-bloom_avx2", "libs/main-bloom_avx512", "libs/main-gptneox_avx2", "libs/main-gptneox_avx512", "libs/main-llama_avx2", "libs/main-llama_avx512", + "libs/main-starcoder_avx512", + "libs/main-starcoder_avx2", ] platform_name = None diff --git a/python/llm/src/bigdl/llm/ggml/model/starcoder/__init__.py b/python/llm/src/bigdl/llm/ggml/model/starcoder/__init__.py new file mode 100644 index 00000000..414e261a --- /dev/null +++ b/python/llm/src/bigdl/llm/ggml/model/starcoder/__init__.py @@ -0,0 +1,22 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This makes sure Python is aware there is more than one sub-package within bigdl,
+# physically located elsewhere.
+# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
+# only search the first bigdl package and end up finding only one sub-package.
+
+from .starcoder import Starcoder
diff --git a/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py b/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py
new file mode 100644
index 00000000..b8d20a71
--- /dev/null
+++ b/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py
@@ -0,0 +1,433 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===========================================================================
+#
+# This file is adapted from
+# https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py
+#
+# MIT License
+#
+# Copyright (c) 2023 Andrei Betlen
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.

+# This makes sure Python is aware there is more than one sub-package within bigdl,
+# physically located elsewhere.
+# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
+# only search the first bigdl package and end up finding only one sub-package.
+
+from .starcoder_cpp import starcoder_load, starcoder_free, starcoder_run
+from .starcoder_cpp import starcoder_tokenize, starcoder_detokenize
+from .starcoder_cpp import starcoder_forward, starcoder_eval, starcoder_embed
+from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.ggml.model.generation import GenerationMixin
+from typing import List, Optional, Generator, Sequence, Union
+import time
+import uuid
+
+
+class Starcoder(GenerationMixin):
+    """High-level Python wrapper for a quantized starcoder model."""
+
+    def __init__(
+        self,
+        model_path: str,
+        n_ctx: int = 512,
+        n_parts: int = -1,
+        n_gpu_layers: int = 0,
+        seed: int = -1,
+        f16_kv: bool = True,
+        logits_all: bool = False,
+        vocab_only: bool = False,
+        use_mmap: bool = True,
+        use_mlock: bool = False,
+        embedding: bool = False,
+        n_threads: Optional[int] = 2,
+        n_batch: int = 512,
+        last_n_tokens_size: int = 64,
+        lora_base: Optional[str] = None,
+        lora_path: Optional[str] = None,
+        verbose: bool = True,
+    ):
+        """Load a quantized starcoder model from `model_path`.
+
+        Args:
+            model_path: Path to the model.
+            n_ctx: Maximum context size.
+            n_parts: Number of parts to split the model into. If -1, the number of parts
+                is determined automatically.
+            n_gpu_layers: Number of layers to offload to the GPU.
+            seed: Random seed. If -1 (the default), the current timestamp is used as the seed.
+            f16_kv: Use half-precision for the key/value cache.
+            logits_all: Return logits for all tokens, not just the last token.
+            vocab_only: Only load the vocabulary, not the weights.
+            use_mmap: Use mmap if possible.
+            use_mlock: Force the system to keep the model in RAM.
+            embedding: Embedding mode only.
+            n_threads: Number of threads to use. Defaults to 2.
+            n_batch: Maximum number of prompt tokens to batch together when calling starcoder_eval.
+            last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
+            lora_base: Optional path to the base model, useful if using a quantized base model and
+                you want to apply LoRA to an f16 model.
+            lora_path: Path to a LoRA file to apply to the model.
+            verbose: Print verbose output to stderr.
+
+        Raises:
+            ValueError: If the model path does not exist.
+
+        Returns:
+            A Starcoder instance.
+        """
+        self.model_path = model_path
+        self.ctx = starcoder_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
+        invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
+        self.n_ctx = n_ctx
+        self.n_parts = n_parts
+        self.n_gpu_layers = n_gpu_layers
+        self.f16_kv = f16_kv
+        self.seed = seed
+        self.logits_all = logits_all
+        self.vocab_only = vocab_only
+        self.use_mmap = use_mmap
+        self.use_mlock = use_mlock
+        self.embedding = embedding
+        self.n_threads = n_threads
+        self.n_batch = n_batch
+        self.last_n_tokens_size = last_n_tokens_size
+        self.lora_base = lora_base
+        self.lora_path = lora_path
+        self.verbose = verbose
+        # TODO: Some parameters are temporarily not supported
+        unsupported_arg = {'n_parts': -1, 'n_gpu_layers': 0, 'f16_kv': True, 'logits_all': False,
+                           'vocab_only': False, 'use_mmap': True, 'use_mlock': False,
+                           'last_n_tokens_size': 64, 'lora_base': None,
+                           'lora_path': None, 'verbose': True}
+        for arg in unsupported_arg.keys():
+            invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
+                              " is temporarily unsupported, please use the default value.")
+
+    def __call__(
+        self,
+        prompt: str,
+        suffix: Optional[str] = None,
+        max_tokens: int = 128,
+        temperature: float = 0.8,
+        top_p: float = 0.95,
+        logprobs: Optional[int] = None,
+        echo: bool = False,
+        stop: Optional[Union[str, List[str]]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
+        repeat_penalty: float = 1.1,
+        top_k: int = 40,
+        stream: bool = False,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
+    ):
+        # TODO: Some parameters are temporarily not supported
+        # Unsupported parameters are checked in `_supported_call`
+        return self._supported_call(prompt, max_tokens, stream, stop, echo, model,
+                                    suffix, temperature, top_p, logprobs, frequency_penalty,
+                                    presence_penalty, repeat_penalty, top_k, tfs_z, mirostat_mode,
+                                    mirostat_tau, mirostat_eta)
+
+    def _supported_call(self, prompt: str, max_tokens: int, stream: bool = False,
+                        stop: Optional[List[str]] = [], echo: bool = False,
+                        model: Optional[str] = None, *args):
+        # Check unsupported parameters
+        unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs',
+                           'frequency_penalty', 'presence_penalty', 'repeat_penalty', 'top_k',
+                           'tfs_z', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'model']
+        default_value = {'suffix': None, 'temperature': 0.8, 'top_p': 0.95, 'logprobs': None,
+                         'frequency_penalty': 0.0, 'presence_penalty': 0.0,
+                         'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
+                         'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
+        for index in range(len(args)):
+            invalidInputError(args[index] == default_value[unsupported_arg[index]],
+                              f"The parameter {unsupported_arg[index]} is temporarily "
+                              "unsupported, please use the default value.")
+
+        if stream:
+            return self.stream(prompt, max_tokens, stop, echo, model)
+        else:
+            return self._eval(prompt, max_tokens, False, stop, echo, model)
+
+    def _eval(self, prompt: str, max_tokens: int, match_str: bool,
+              stop: Optional[List[str]] = [], echo: bool = False,
+              model: Optional[str] = None):
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
+        if model is None:
+            model_name = self.model_path
+        else:
+            model_name = model
+        prompt_len = len(self.tokenize(prompt))
+        if prompt.endswith("<|endoftext|>") or max_tokens < 1:
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": 
created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": prompt if echo else "",
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": "length",
+                    }
+                ],
+                "usage":
+                {
+                    "prompt_tokens": prompt_len,
+                    "completion_tokens": 0,
+                    "total_tokens": prompt_len,
+                }
+            }
+        # Use `buf` to store the prompt and the generated string;
+        # assume the average token length is less than 20 bytes
+        buf = bytes((prompt_len + max_tokens) * 20)
+        ret = starcoder_run(ctx=self.ctx,
+                            seed=self.seed,
+                            n_threads=self.n_threads,
+                            n_batch=self.n_batch,
+                            n_predict=max_tokens,
+                            match_str=match_str,
+                            prompt=bytes(prompt, encoding='utf-8'),
+                            buf=buf)
+        s = str(buf, encoding='utf-8').rstrip("\x00")
+
+        text = s.split(prompt)[1]
+        split_text = text
+        if stop != []:
+            for stop_word in stop:
+                split_text = split_text.split(stop_word)[0]
+        if split_text != text:
+            finish_reason = "stop"
+        else:
+            finish_reason = None
+        completion_len = len(self.tokenize(split_text))
+        return {"id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": prompt + split_text if echo else split_text,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+                "usage":
+                {
+                    "prompt_tokens": prompt_len,
+                    "completion_tokens": completion_len,
+                    "total_tokens": prompt_len + completion_len,
+                }
+                }
+
+    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = [],
+               echo: bool = False, model: Optional[str] = None):
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
+        if model is None:
+            model_name = self.model_path
+        else:
+            model_name = model
+        prompt_tokens: List[int] = self.tokenize(prompt.encode("utf-8"))
+        prompt_len = len(prompt_tokens)
+        if prompt.endswith("<|endoftext|>") or max_tokens < 1:
+            yield {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": prompt if echo else "",
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": "length",
+                    }
+                ],
+                "usage":
+                {
+                    "prompt_tokens": prompt_len
+                }
+            }
+        else:
+            for i in range(max_tokens):
+                token = self.forward(prompt_tokens)
+                prompt_tokens.append(token)
+                text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                if text.endswith("<|endoftext|>"):
+                    # The end-of-text marker itself is not yielded
+                    print('\n')
+                else:
+                    yield {
+                        "id": completion_id,
+                        "object": "text_completion",
+                        "created": created,
+                        "model": model_name,
+                        "choices": [
+                            {
+                                "text": text,
+                                "index": 0,
+                                "logprobs": None,
+                                "finish_reason": None,
+                            }
+                        ],
+                        "usage":
+                        {
+                            "prompt_tokens": prompt_len
+                        }
+                    }
+
+    def free(self):
+        starcoder_free(self.ctx)
+
+    def _tokenize(self, text: bytes, add_bos: bool = False) -> List[int]:
+        """Tokenize a string.
+
+        Args:
+            text: The utf-8 encoded string to tokenize.
+
+        Raises:
+            RuntimeError: If the tokenization failed.
+
+        Returns:
+            A list of tokens.
+        """
+        invalidInputError(self.ctx is not None,
+                          "The attribute `ctx` of `Starcoder` object is None.")
+        return starcoder_tokenize(self.ctx, text, False)
+
+    def detokenize(self, tokens: List[int]) -> bytes:
+        """Detokenize a list of tokens.
+
+        Args:
+            tokens: The list of tokens to detokenize.
+
+        Returns:
+            The detokenized string.
+        """
+        invalidInputError(self.ctx is not None,
+                          "The attribute `ctx` of `Starcoder` object is None.")
+        output = ""
+        for token in tokens:
+            output += starcoder_detokenize(self.ctx, token)
+        return output.encode('utf-8')
+
+    def forward(self, input_ids: List[int]) -> int:
+        return starcoder_forward(ctx=self.ctx,
+                                 input_ids=input_ids,
+                                 seed=self.seed,
+                                 n_threads=self.n_threads,
+                                 n_batch=self.n_batch)
+
+    def eval(self, input_ids: List[int]) -> List[List[float]]:
+        """Only used for testing accuracy."""
+        return starcoder_eval(ctx=self.ctx,
+                              input_ids=input_ids,
+                              seed=self.seed,
+                              n_threads=self.n_threads,
+                              n_batch=len(input_ids))
+
+    def _generate(
+        self,
+        tokens: Sequence[int],
+        top_k: int = 40,
+        top_p: float = 0.95,
+        temp: float = 0.80,
+        repeat_penalty: float = 1.1,
+        reset: bool = True,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+    ) -> Generator[int, Optional[Sequence[int]], None]:
+        """Create a generator of tokens from a prompt.
+
+        Examples:
+            >>> llm = Starcoder(your_model_path)
+            >>> tokens = llm._tokenize(b"Learning English is")
+            >>> for token in llm._generate(tokens):
+            ...     print(llm.detokenize([token]).decode("utf-8", errors="ignore"))
+
+        Args:
+            tokens: The prompt tokens.
+
+        Yields:
+            The generated tokens.
+        """
+        # TODO: Some parameters are temporarily not supported
+        # Unsupported parameters are checked in `_supported_generate`
+        return self._supported_generate(tokens, top_k, top_p, temp, repeat_penalty, reset,
+                                        frequency_penalty, presence_penalty, tfs_z, mirostat_mode,
+                                        mirostat_tau, mirostat_eta)
+
+    def _supported_generate(self, tokens: Sequence[int], *args):
+        # Check unsupported parameters
+        unsupported_arg = ['top_k', 'top_p', 'temp', 'repeat_penalty', 'reset',
+                           'frequency_penalty', 'presence_penalty', 'tfs_z', 'mirostat_mode',
+                           'mirostat_tau', 'mirostat_eta']
+        default_value = {'top_k': 40, 'top_p': 0.95, 'temp': 0.80, 'repeat_penalty': 1.1,
+                         'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
+                         'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0,
+                         'mirostat_eta': 0.1}
+        for index in range(len(args)):
+            invalidInputError(args[index] == default_value[unsupported_arg[index]],
+                              f"The parameter {unsupported_arg[index]} is temporarily "
+                              "unsupported, please use the default value.")
+
+        invalidInputError(self.ctx is not None,
+                          "The attribute `ctx` of `Starcoder` object is None.")
+        while True:
+            token = self.forward(tokens)
+            tokens_or_none = yield token
+            tokens.append(token)
+            if tokens_or_none is not None:
+                tokens.extend(tokens_or_none)
+
+    def embed(self, input: str) -> List[float]:
+        """Only used for langchain."""
+        invalidInputError(self.embedding,
+                          "Starcoder model must be created with embedding=True "
+                          "to call this method.")
+        input_ids = self.tokenize(input)
+        return starcoder_embed(ctx=self.ctx,
+                               input_ids=input_ids,
+                               seed=self.seed,
+                               n_threads=self.n_threads,
+                               n_batch=len(input_ids))
diff --git a/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder_cpp.py b/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder_cpp.py
new file mode 100644
index 00000000..7de757db
--- /dev/null
+++ b/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder_cpp.py
@@ -0,0 +1,239 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===========================================================================
+#
+# This file is adapted from
+# https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama_cpp.py
+#
+# MIT License
+#
+# Copyright (c) 2023 Andrei Betlen
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.

+# This makes sure Python is aware there is more than one sub-package within bigdl,
+# physically located elsewhere.
+# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
+# only search the first bigdl package and end up finding only one sub-package.
+
+import sys
+import os
+import ctypes
+from typing import List
+from ctypes import (
+    c_int,
+    c_long,
+    c_float,
+    c_char_p,
+    c_void_p,
+    c_bool,
+    POINTER,
+    pointer,
+    Structure,
+    Array,
+    c_uint8,
+    c_size_t,
+)
+import pathlib
+from bigdl.llm.utils import get_avx_flags
+from bigdl.llm.utils.common import invalidInputError
+
+
+# Load the library
+def _load_shared_library(lib_base_name: str):
+    # Determine the file extension based on the platform
+    if sys.platform.startswith("linux") or sys.platform == "darwin":
+        lib_ext = ".so"
+    elif sys.platform == "win32":
+        lib_ext = ".dll"
+    else:
+        invalidInputError(False, "Unsupported platform")
+
+    avx = get_avx_flags()
+
+    # Construct the paths to the possible shared library names (python/llm/src/bigdl/llm/libs)
+    _base_path = pathlib.Path(__file__).parent.parent.parent.parent.resolve()
+    _base_path = _base_path / 'libs'
+    # Search for the library under the names "libstarcoder" (Linux/macOS naming)
+    # and "starcoder" (Windows naming), with the detected AVX suffix
+    _lib_paths = [
+        _base_path / f"lib{lib_base_name}{avx}{lib_ext}",
+        _base_path / f"{lib_base_name}{avx}{lib_ext}",
+    ]
+
+    if "STARCODER_CPP_LIB" in os.environ:
+        lib_base_name = os.environ["STARCODER_CPP_LIB"]
+        _lib = pathlib.Path(lib_base_name)
+        _base_path = _lib.parent.resolve()
+        _lib_paths = [_lib.resolve()]
+
+    # Add the library directory to the DLL search path on Windows (if needed)
+    if sys.platform == "win32" and sys.version_info >= (3, 8):
+        os.add_dll_directory(str(_base_path))
+
+    # Try to load the shared library, handling potential errors
+    for _lib_path in _lib_paths:
+        if _lib_path.exists():
+            try:
+                return ctypes.CDLL(str(_lib_path))
+            except Exception as e:
+                invalidInputError(False,
+                                  f"Failed to load shared library '{_lib_path}': {e}")
+
+    invalidInputError(False, f"Shared library with base name '{lib_base_name}' not found")
+
+
+# Specify the base name of the shared library to load
+_lib_base_name = "starcoder"
+
+# Load the library
+_lib = _load_shared_library(_lib_base_name)
+
+
+def c_free(p: c_void_p):
+    _lib.c_free(p)
+
+
+_lib.c_free.argtypes = [c_void_p]
+_lib.c_free.restype = None
+
+
+def starcoder_load(fname: bytes, n_ctx: c_int, n_threads: c_int) -> c_void_p:
+    return _lib.starcoder_load(fname, n_ctx, n_threads)
+
+
+_lib.starcoder_load.argtypes = [c_char_p, c_int, c_int]
+_lib.starcoder_load.restype = c_void_p
+
+
+def starcoder_free(ctx: c_void_p):
+    return _lib.starcoder_free(ctx)
+
+
+_lib.starcoder_free.argtypes = [c_void_p]
+_lib.starcoder_free.restype = None
+
+
+def starcoder_run(ctx: c_void_p,
+                  seed: c_int,
+                  n_threads: c_int,
+                  n_batch: c_int,
+                  n_predict: c_int,
+                  match_str: c_bool,
+                  prompt: bytes,
+                  buf: bytes) -> c_int:
+    return _lib.starcoder_run(ctx, seed, n_threads, n_batch, n_predict, match_str, prompt, buf)
+
+
+_lib.starcoder_run.argtypes = [c_void_p, c_int, c_int, c_int, c_int, c_bool, c_char_p, c_char_p]
+_lib.starcoder_run.restype = c_int
+
+
+def starcoder_tokenize(ctx: c_void_p,
+                       prompt: bytes,
+                       bos: bool = False) -> List[int]:
+    n_tokens = c_int(0)
+    c_tokens = _lib.tokenize_api(ctx, prompt, bos, pointer(n_tokens))
+    tokens = [c_tokens[i] for i in range(0, n_tokens.value)]
+    c_free(c_tokens)
+    return tokens
+
+
+_lib.tokenize_api.argtypes = [c_void_p, c_char_p, c_bool, c_void_p]
+_lib.tokenize_api.restype = POINTER(c_int)
+
+
+def starcoder_detokenize(ctx: c_void_p,
+                         token_id: c_int) -> str:
+    c_chars = _lib.detokenize_api(ctx, token_id)
+    s = c_chars.decode('utf-8')
+    return s
+
+
+_lib.detokenize_api.argtypes = [c_void_p, c_int]
+_lib.detokenize_api.restype = c_char_p
+
+
+def starcoder_eval(ctx: c_void_p,
+                   input_ids: List[int],
+                   seed: c_int,
+                   n_threads: c_int,
+                   n_batch: c_int) -> List[List[float]]:
+    length = len(input_ids)
+    c_input_ids = (c_int * length)(*input_ids)
+    n_logits = c_long(0)
+    c_logits = _lib.eval_api(ctx, c_input_ids, length, seed, n_threads, n_batch, pointer(n_logits))
+    n_vocab = n_logits.value // length
+    assert n_vocab * length == n_logits.value
+    logits = [[c_logits[i * n_vocab + j] for j in range(n_vocab)] for i in range(length)]
+    # Do not free c_logits; the buffer is owned by the native library
+    return logits
+
+
+_lib.eval_api.argtypes = [c_void_p, c_void_p, c_int, c_int, c_int, c_int, c_void_p]
+_lib.eval_api.restype = POINTER(c_float)
+
+
+def starcoder_embed(ctx: c_void_p,
+                    input_ids: List[int],
+                    seed: c_int,
+                    n_threads: c_int,
+                    n_batch: c_int) -> List[float]:
+    length = len(input_ids)
+    c_input_ids = (c_int * length)(*input_ids)
+    n_embd = c_long(0)
+    c_embeddings = _lib.embed_api(ctx, c_input_ids, length, seed, n_threads,
+                                  n_batch, pointer(n_embd))
+    embeddings = [c_embeddings[i] for i in range(n_embd.value)]
+    # Do not free c_embeddings; the buffer is owned by the native library
+    return embeddings
+
+
+_lib.embed_api.argtypes = [c_void_p, c_void_p, c_int, c_int, c_int, c_int, c_void_p]
+_lib.embed_api.restype = POINTER(c_float)
+
+
+def starcoder_forward(ctx: c_void_p,
+                      input_ids: List[int],
+                      seed: c_int,
+                      n_threads: c_int,
+                      n_batch: c_int) -> int:
+    length = len(input_ids)
+    c_input_ids = (c_int * length)(*input_ids)
+    token_id = _lib.forward_api(ctx, c_input_ids, length, seed, n_threads, n_batch)
+    return token_id
+
+
+_lib.forward_api.argtypes = [c_void_p, c_void_p, c_int, c_int, c_int, c_int]
+_lib.forward_api.restype = c_int
+
+# ------------------------------------------------------------------- #
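
Reviewer note: a minimal usage sketch of the new pybinding (the checkpoint path, thread count, and prompt below are hypothetical placeholders, not part of this patch):

    from bigdl.llm.ggml.model.starcoder import Starcoder

    # Load a quantized GGML StarCoder checkpoint (hypothetical path).
    llm = Starcoder("./starcoder-ggml-q4_0.bin", n_ctx=512, n_threads=4)

    # One-shot completion. Per `_supported_call`, only prompt, max_tokens,
    # stream, stop, echo, and model are tunable for now; the other sampling
    # parameters must stay at their defaults.
    result = llm("def fibonacci(n):", max_tokens=64)
    print(result["choices"][0]["text"])

    # Streaming mode yields one OpenAI-style chunk per generated token.
    for chunk in llm("def fibonacci(n):", max_tokens=64, stream=True):
        print(chunk["choices"][0]["text"], end="", flush=True)

    # Release the native context when done.
    llm.free()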