diff --git a/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py b/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py
index bd46cc1b..eacd808f 100644
--- a/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py
+++ b/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py
@@ -46,7 +46,7 @@
 # only search the first bigdl package and end up finding only one sub-package.
 
 from .bloom_cpp import bloom_load, bloom_free, bloom_run
-from .bloom_cpp import bloom_tokenize, bloom_detokenize, bloom_forward, bloom_eval
+from .bloom_cpp import bloom_tokenize, bloom_detokenize, bloom_forward, bloom_eval, bloom_embed
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.model.generation import GenerationMixin
 from typing import List, Optional, Generator, Sequence, Union
@@ -127,7 +127,7 @@ class Bloom(GenerationMixin):
         # TODO: Some parameters are temporarily not supported
         unsupported_arg = {'n_parts': -1, 'n_gpu_layers': 0, 'f16_kv': True,
                            'logits_all': False, 'vocab_only': False, 'use_mmap': True, 'use_mlock': False,
-                           'embedding': False, 'last_n_tokens_size': 64, 'lora_base': None,
+                           'last_n_tokens_size': 64, 'lora_base': None,
                            'lora_path': None, 'verbose': True}
         for arg in unsupported_arg.keys():
             invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
@@ -156,44 +156,50 @@ class Bloom(GenerationMixin):
     ):
         # TODO: Some parameters are temporarily not supported
         # Unsupported parameters are checked in `_supported_call`
-        return self._supported_call(prompt, max_tokens, stream, stop,
-                                    suffix, temperature, top_p, logprobs, echo, frequency_penalty,
+        return self._supported_call(prompt, max_tokens, stream, stop, echo, model,
+                                    suffix, temperature, top_p, logprobs, frequency_penalty,
                                     presence_penalty, repeat_penalty, top_k, tfs_z, mirostat_mode,
-                                    mirostat_tau, mirostat_eta, model)
+                                    mirostat_tau, mirostat_eta)
 
     def _supported_call(self, prompt: str, max_tokens: int, stream: bool = False,
-                        stop: Optional[List[str]] = [], *args):
+                        stop: Optional[List[str]] = [], echo: bool = False,
+                        model: Optional[str] = None, *args):
         # Check unsupporeted parameters
-        unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs', 'echo',
+        unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs',
                            'frequency_penalty', 'presence_penalty', 'repeat_penalty', 'top_k',
                            'tfs_z', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'model']
         defult_value = {'suffix': None, 'temperature': 0.8, 'top_p': 0.95, 'logprobs': None,
-                        'echo': False, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
+                        'frequency_penalty': 0.0, 'presence_penalty': 0.0,
                         'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
-                        'mirostat_tau': 5.0, 'mirostat_eta': 0.1, 'model': None}
+                        'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
         for index in range(len(args)):
             invalidInputError(args[index] == defult_value[unsupported_arg[index]],
                               f"The parameter {unsupported_arg[index]} is temporarily "
                               "unsupported, please use the default value.")
         if stream:
-            return self.stream(prompt, max_tokens, stop)
+            return self.stream(prompt, max_tokens, stop, echo, model)
         else:
-            return self._eval(prompt, max_tokens, False, stop)
+            return self._eval(prompt, max_tokens, False, stop, echo, model)
 
     def _eval(self, prompt: str, max_tokens: int, match_str: bool,
-              stop: Optional[List[str]] = []):
+              stop: Optional[List[str]] = [], echo: bool = False, model: Optional[str] = None):
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        if model is None:
+            model_name = self.model_path
+        else:
+            model_name = model
+        prompt_len = len(self.tokenize(prompt))
         if prompt.endswith("</s>") or max_tokens < 1:
             return {
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt,
+                        "text": prompt if echo else "",
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": "length",
@@ -201,15 +207,14 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None,
-                    "completion_tokens": None,
-                    "total_tokens": None,
+                    "prompt_tokens": prompt_len,
+                    "completion_tokens": 0,
+                    "total_tokens": prompt_len,
                 }
             }
         # use `buf` to store prompt and generated string,
        # assume the average length of words is less than 20 bytes
-        buf = bytes((len(prompt) + max_tokens) * 20)
+        buf = bytes((prompt_len + max_tokens) * 20)
         ret = bloom_run(ctx=self.ctx,
                         seed=self.seed,
                         n_threads=self.n_threads,
@@ -229,13 +234,14 @@ class Bloom(GenerationMixin):
             finish_reason = "stop"
         else:
             finish_reason = None
+        completion_len = len(self.tokenize(split_text))
         return {"id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt + split_text,
+                        "text": prompt + split_text if echo else split_text,
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": finish_reason,
@@ -243,25 +249,31 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None,
-                    "completion_tokens": None,
-                    "total_tokens": None,
+                    "prompt_tokens": prompt_len,
+                    "completion_tokens": completion_len,
+                    "total_tokens": prompt_len + completion_len,
                 }
                 }
 
-    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = []):
+    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = [],
+               echo: bool = False, model: Optional[str] = None):
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        if model is None:
+            model_name = self.model_path
+        else:
+            model_name = model
+        prompt_tokens: List[int] = self.tokenize(prompt.encode("utf-8"))
+        prompt_len = len(prompt_tokens)
         if prompt.endswith("</s>") or max_tokens < 1:
             yield {
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt,
+                        "text": prompt if echo else "",
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": "length",
@@ -269,35 +281,35 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None
+                    "prompt_tokens": prompt_len
                 }
             }
         else:
             for i in range(max_tokens):
-                if prompt.endswith("</s>"):
+                token = self.forward(prompt_tokens)
+                prompt_tokens.append(token)
+                text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                if text.endswith("</s>"):
+                    print('\n')
                     break
-                else:
-                    prompt = self._eval(prompt, 1, i != 0, stop)['choices'][0]['text']
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": self.model_path,
-                        "choices": [
-                            {
-                                "text": prompt,
-                                "index": 0,
-                                "logprobs": None,
-                                "finish_reason": None,
-                            }
-                        ],
-                        "usage":
-                        {
-                            # TODO: need tokenize
-                            "prompt_tokens": None
+                yield {
+                    "id": completion_id,
+                    "object": "text_completion",
+                    "created": created,
+                    "model": model_name,
+                    "choices": [
+                        {
+                            "text": text,
+                            "index": 0,
+                            "logprobs": None,
+                            "finish_reason": None,
                         }
+                    ],
+                    "usage":
+                    {
+                        "prompt_tokens": prompt_len
                    }
+                }
 
     def free(self):
         bloom_free(self.ctx)
@@ -403,9 +415,11 @@ class Bloom(GenerationMixin):
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)
 
-    def embed(self, prompt: Union[str, bytes]) -> List[float]:
+    def embed(self, input: str) -> List[float]:
         """Only used for langchain"""
-        input_ids = self.tokenize(prompt)
+        invalidInputError(self.embedding,
+                          "Bloom model must be created with embedding=True to call this method.")
+        input_ids = self.tokenize(input)
         return bloom_embed(ctx=self.ctx,
                            input_ids=input_ids,
                            seed=self.seed,
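# Usage sketch (not part of the patch above): a minimal, hypothetical example of how the
# updated Bloom API might be exercised after this change. The import path, constructor
# arguments, and model file name below are assumptions for illustration only; only the
# `echo`/`model` completion parameters, the token-by-token streaming, and the
# embedding=True requirement for embed() come from the diff itself.
from bigdl.llm.ggml.model.bloom import Bloom  # assumed import path

llm = Bloom(model_path="./bloom-7b1-ggml-q4_0.bin", embedding=True)  # assumed model file

# Non-streaming completion: `usage` now carries real token counts, `model` overrides the
# reported model name, and `echo=True` prepends the prompt to the returned text.
out = llm("What is machine learning?", max_tokens=32, echo=True, model="bloom-7b1")
print(out["choices"][0]["text"])
print(out["usage"])  # {'prompt_tokens': ..., 'completion_tokens': ..., 'total_tokens': ...}

# Streaming completion: each yielded chunk now contains one newly generated token,
# produced via forward()/detokenize() instead of repeated _eval() calls.
for chunk in llm("What is machine learning?", max_tokens=32, stream=True):
    print(chunk["choices"][0]["text"], end="", flush=True)

# Embedding (langchain path): only valid when the model was created with embedding=True,
# otherwise invalidInputError is raised.
vector = llm.embed("hello world")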