diff --git a/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py b/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py index 09432482..bd46cc1b 100644 --- a/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py +++ b/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py @@ -46,35 +46,57 @@ # only search the first bigdl package and end up finding only one sub-package. from .bloom_cpp import bloom_load, bloom_free, bloom_run +from .bloom_cpp import bloom_tokenize, bloom_detokenize, bloom_forward, bloom_eval from bigdl.llm.utils.common import invalidInputError -from typing import List, Optional +from bigdl.llm.ggml.model.generation import GenerationMixin +from typing import List, Optional, Generator, Sequence, Union import time import uuid -class Bloom: +class Bloom(GenerationMixin): """High-level Python wrapper for a bloom.cpp model.""" - def __init__(self, - model_path: str, - n_ctx: int = 512, - seed: int = 1337, - logits_all: bool = False, - n_threads: int = 2, - n_batch: int = 8, - last_n_tokens_size: int = 64, - verbose: bool = True, - ): + def __init__( + self, + model_path: str, + n_ctx: int = 512, + n_parts: int = -1, + n_gpu_layers: int = 0, + seed: int = -1, + f16_kv: bool = True, + logits_all: bool = False, + vocab_only: bool = False, + use_mmap: bool = True, + use_mlock: bool = False, + embedding: bool = False, + n_threads: Optional[int] = 2, + n_batch: int = 512, + last_n_tokens_size: int = 64, + lora_base: Optional[str] = None, + lora_path: Optional[str] = None, + verbose: bool = True, + ): """Load a bloom.cpp model from `model_path`. Args: model_path: Path to the model. n_ctx: Maximum context size. - seed: Random seed. 0 for random. + n_parts: Number of parts to split the model into. If -1, the number of parts + is automatically determined. + seed: Random seed. For default value -1, current timestamp is used as seed. + f16_kv: Use half-precision for key/value cache. logits_all: Return logits for all tokens, not just the last token. + vocab_only: Only load the vocabulary no weights. + use_mmap: Use mmap if possible. + use_mlock: Force the system to keep the model in RAM. + embedding: Embedding mode only. n_threads: Number of threads to use. Default to be 2. - n_batch: Maximum number of prompt tokens to batch together when calling llama_eval. + n_batch: Maximum number of prompt tokens to batch together when calling bloom_eval. last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque. + lora_base: Optional path to base model, useful if using a quantized base model and + you want to apply LoRA to an f16 model. + lora_path: Path to a LoRA file to apply to the model. verbose: Print verbose output to stderr. Raises: @@ -87,15 +109,73 @@ class Bloom: self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads) invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}") self.n_ctx = n_ctx + self.n_parts = n_parts + self.n_gpu_layers = n_gpu_layers + self.f16_kv = f16_kv self.seed = seed self.logits_all = logits_all + self.vocab_only = vocab_only + self.use_mmap = use_mmap + self.use_mlock = use_mlock + self.embedding = embedding self.n_threads = n_threads self.n_batch = n_batch self.last_n_tokens_size = last_n_tokens_size + self.lora_base = lora_base + self.lora_path = lora_path self.verbose = verbose + # TODO: Some parameters are temporarily not supported + unsupported_arg = {'n_parts': -1, 'n_gpu_layers': 0, 'f16_kv': True, 'logits_all': False, + 'vocab_only': False, 'use_mmap': True, 'use_mlock': False, + 'embedding': False, 'last_n_tokens_size': 64, 'lora_base': None, + 'lora_path': None, 'verbose': True} + for arg in unsupported_arg.keys(): + invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}" + " is temporarily unsupported, please use the default value.") + + def __call__( + self, + prompt: str, + suffix: Optional[str] = None, + max_tokens: int = 128, + temperature: float = 0.8, + top_p: float = 0.95, + logprobs: Optional[int] = None, + echo: bool = False, + stop: Optional[Union[str, List[str]]]=[], + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + repeat_penalty: float = 1.1, + top_k: int = 40, + stream: bool = False, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + ): + # TODO: Some parameters are temporarily not supported + # Unsupported parameters are checked in `_supported_call` + return self._supported_call(prompt, max_tokens, stream, stop, + suffix, temperature, top_p, logprobs, echo, frequency_penalty, + presence_penalty, repeat_penalty, top_k, tfs_z, mirostat_mode, + mirostat_tau, mirostat_eta, model) + + def _supported_call(self, prompt: str, max_tokens: int, stream: bool = False, + stop: Optional[List[str]] = [], *args): + # Check unsupporeted parameters + unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs', 'echo', + 'frequency_penalty', 'presence_penalty', 'repeat_penalty', 'top_k', + 'tfs_z', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'model'] + defult_value = {'suffix': None, 'temperature': 0.8, 'top_p': 0.95, 'logprobs': None, + 'echo': False, 'frequency_penalty': 0.0, 'presence_penalty': 0.0, + 'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0, + 'mirostat_tau': 5.0, 'mirostat_eta': 0.1, 'model': None} + for index in range(len(args)): + invalidInputError(args[index] == defult_value[unsupported_arg[index]], + f"The parameter {unsupported_arg[index]} is temporarily " + "unsupported, please use the default value.") - def __call__(self, prompt: str, max_tokens: int = 128, stream: bool = False, - stop: Optional[List[str]] = []): if stream: return self.stream(prompt, max_tokens, stop) else: @@ -221,3 +301,113 @@ class Bloom: def free(self): bloom_free(self.ctx) + + def _tokenize(self, text: bytes, add_bos: bool = False) -> List[int]: + """Tokenize a string. + + Args: + text: The utf-8 encoded string to tokenize. + + Raises: + RuntimeError: If the tokenization failed. + + Returns: + A list of tokens. + """ + invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.") + return bloom_tokenize(self.ctx, text, False) + + def detokenize(self, tokens: List[int]) -> bytes: + """Detokenize a list of tokens. + + Args: + tokens: The list of tokens to detokenize. + + Returns: + The detokenized string. + """ + invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.") + output = "" + for token in tokens: + output += bloom_detokenize(self.ctx, token) + return output.encode('utf-8') + + def forward(self, input_ids: List[int]) -> int: + return bloom_forward(ctx=self.ctx, + input_ids=input_ids, + seed=self.seed, + n_threads=self.n_threads, + n_batch=self.n_batch) + + def eval(self, input_ids: List[int]) -> List[List[float]]: + """Only used for testing accuracy""" + return bloom_eval(ctx=self.ctx, + input_ids=input_ids, + seed=self.seed, + n_threads=self.n_threads, + n_batch=len(input_ids)) + + def _generate( + self, + tokens: Sequence[int], + top_k: int = 40, + top_p: float = 0.95, + temp: float = 0.80, + repeat_penalty: float = 1.1, + reset: bool = True, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + ) -> Generator[int, Optional[Sequence[int]], None]: + """Create a generator of tokens from a prompt. + + Examples: + >>> llm = Bloom(your_model_path) + >>> tokens = llm._tokenize(b"Learning English is") + >>> for token in llm._generate(tokens): + >>> print(llm.detokenize([token]).decode("utf-8", errors="ignore")) + + Args: + tokens: The prompt tokens. + + Yields: + The generated tokens. + """ + # TODO: Some parameters are temporarily not supported + # Unsupported parameters are checked in `_supported_generate` + return self._supported_generate(tokens, top_k, top_p, temp, repeat_penalty, reset, + frequency_penalty, presence_penalty, tfs_z, mirostat_mode, + mirostat_tau, mirostat_eta) + + def _supported_generate(self, tokens: Sequence[int], *args): + # Check unsupporeted parameters + unsupported_arg = ['top_k', 'top_p', 'temp', 'repeat_penalty', 'reset', + 'frequency_penalty', 'presence_penalty', 'tfs_z', 'mirostat_mode', + 'mirostat_tau', 'mirostat_eta'] + defult_value = {'top_k': 40, 'top_p': 0.95, 'temp': 0.80, 'repeat_penalty': 1.1, + 'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0, + 'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1} + for index in range(len(args)): + invalidInputError(args[index] == defult_value[unsupported_arg[index]], + f"The parameter {unsupported_arg[index]} is temporarily " + "unsupported, please use the default value.") + + invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.") + while True: + token = self.forward(tokens) + tokens_or_none = yield token + tokens.append(token) + if tokens_or_none is not None: + tokens.extend(tokens_or_none) + + def embed(self, prompt: Union[str, bytes]) -> List[float]: + """Only used for langchain""" + input_ids = self.tokenize(prompt) + return bloom_embed(ctx=self.ctx, + input_ids=input_ids, + seed=self.seed, + n_threads=self.n_threads, + n_batch=len(input_ids)) diff --git a/python/llm/src/bigdl/llm/ggml/model/bloom/bloom_cpp.py b/python/llm/src/bigdl/llm/ggml/model/bloom/bloom_cpp.py index 9286ea8d..18b5d65c 100644 --- a/python/llm/src/bigdl/llm/ggml/model/bloom/bloom_cpp.py +++ b/python/llm/src/bigdl/llm/ggml/model/bloom/bloom_cpp.py @@ -48,13 +48,16 @@ import sys import os import ctypes +from typing import List from ctypes import ( c_int, + c_long, c_float, c_char_p, c_void_p, c_bool, POINTER, + pointer, Structure, Array, c_uint8, @@ -116,6 +119,14 @@ _lib_base_name = "bloom" _lib = _load_shared_library(_lib_base_name) +def c_free(p: c_void_p): + _lib.c_free(p) + + +_lib.c_free.argtypes = [c_void_p] +_lib.c_free.restype = None + + def bloom_load(fname: bytes, n_ctx: c_int, n_threads: c_int) -> c_void_p: return _lib.bloom_load(fname, n_ctx, n_threads) @@ -146,4 +157,83 @@ def bloom_run(ctx: c_void_p, _lib.bloom_run.argtypes = [c_void_p, c_int, c_int, c_int, c_int, c_bool, c_char_p, c_char_p] _lib.bloom_run.restype = c_int + +def bloom_tokenize(ctx: c_void_p, + prompt: bytes, + bos: bool = False) -> List[int]: + n_tokens = c_int(0) + c_tokens = _lib.tokenize_api(ctx, prompt, bos, pointer(n_tokens)) + tokens = [c_tokens[i] for i in range(0, n_tokens.value)] + c_free(c_tokens) + return tokens + + +_lib.tokenize_api.argtypes = [c_void_p, c_char_p, c_bool, c_void_p] +_lib.tokenize_api.restype = POINTER(c_int) + + +def bloom_detokenize(ctx: c_void_p, + token_id: c_int) -> str: + c_chars = _lib.detokenize_api(ctx, token_id) + s = c_chars.decode('utf-8') + return s + + +_lib.detokenize_api.argtypes = [c_void_p, c_int] +_lib.detokenize_api.restype = c_char_p + + +def bloom_eval(ctx: c_void_p, + input_ids: List[int], + seed: c_int, + n_threads: c_int, + n_batch: c_int) -> List[List[float]]: + length = len(input_ids) + c_input_ids = (c_int * length)(*input_ids) + n_logits = c_long(0) + c_logits = _lib.eval_api(ctx, c_input_ids, length, seed, n_threads, n_batch, pointer(n_logits)) + n_vocab = n_logits.value // length + assert(n_vocab * length == n_logits.value) + logits = [[c_logits[i * n_vocab + j] for j in range(n_vocab)] for i in range(length)] + # do not free c_logits + return logits + + +_lib.eval_api.argtypes = [c_void_p, c_void_p, c_int, c_int, c_int, c_int, c_void_p] +_lib.eval_api.restype = POINTER(c_float) + + +def bloom_embed(ctx: c_void_p, + input_ids: List[int], + seed: c_int, + n_threads: c_int, + n_batch: c_int) -> List[float]: + length = len(input_ids) + c_input_ids = (c_int * length)(*input_ids) + n_embd = c_long(0) + c_embeddings = _lib.embed_api(ctx, c_input_ids, length, seed, n_threads, + n_batch, pointer(n_embd)) + embeddings = [c_embeddings[i] for i in range(n_embd.value)] + # do not free c_embeddings + return embeddings + + +_lib.embed_api.argtypes = [c_void_p, c_void_p, c_int, c_int, c_int, c_int, c_void_p] +_lib.embed_api.restype = POINTER(c_float) + + +def bloom_forward(ctx: c_void_p, + input_ids: List[int], + seed: c_int, + n_threads: c_int, + n_batch: c_int) -> int: + length = len(input_ids) + c_input_ids = (c_int * length)(*input_ids) + token_id = _lib.forward_api(ctx, c_input_ids, length, seed, n_threads, n_batch) + return token_id + + +_lib.forward_api.argtypes = [c_void_p, c_void_p, c_int, c_int, c_int, c_int] +_lib.forward_api.restype = c_int + # ------------------------------------------------------------------- # diff --git a/python/llm/src/bigdl/llm/ggml/model/gptneox/gptneox.py b/python/llm/src/bigdl/llm/ggml/model/gptneox/gptneox.py index 398436c4..a4461b27 100644 --- a/python/llm/src/bigdl/llm/ggml/model/gptneox/gptneox.py +++ b/python/llm/src/bigdl/llm/ggml/model/gptneox/gptneox.py @@ -132,7 +132,7 @@ class Gptneox(GenerationMixin): n_ctx: int = 512, n_parts: int = -1, n_gpu_layers: int = 0, - seed: int = 1337, + seed: int = -1, f16_kv: bool = True, logits_all: bool = False, vocab_only: bool = False, @@ -153,7 +153,7 @@ class Gptneox(GenerationMixin): n_ctx: Maximum context size. n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined. - seed: Random seed. 0 for random. + seed: Random seed. For default value -1, current timestamp is used as seed. f16_kv: Use half-precision for key/value cache. logits_all: Return logits for all tokens, not just the last token. vocab_only: Only load the vocabulary no weights. diff --git a/python/llm/src/bigdl/llm/ggml/model/llama/llama.py b/python/llm/src/bigdl/llm/ggml/model/llama/llama.py index 669757d0..aafd87fc 100644 --- a/python/llm/src/bigdl/llm/ggml/model/llama/llama.py +++ b/python/llm/src/bigdl/llm/ggml/model/llama/llama.py @@ -130,7 +130,7 @@ class Llama(GenerationMixin): n_ctx: int = 512, n_parts: int = -1, n_gpu_layers: int = 0, - seed: int = 1337, + seed: int = -1, f16_kv: bool = True, logits_all: bool = False, vocab_only: bool = False, @@ -151,7 +151,7 @@ class Llama(GenerationMixin): n_ctx: Maximum context size. n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined. - seed: Random seed. 0 for random. + seed: Random seed. For default value -1, current timestamp is used as seed. f16_kv: Use half-precision for key/value cache. logits_all: Return logits for all tokens, not just the last token. vocab_only: Only load the vocabulary no weights.