diff --git a/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py b/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py
index ba903c1e..242962f2 100644
--- a/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py
+++ b/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py
@@ -47,6 +47,9 @@
 
 from .bloom_cpp import bloom_load, bloom_free, bloom_run
 from bigdl.llm.utils.common import invalidInputError
+from typing import List, Optional
+import time
+import uuid
 
 
 class Bloom:
@@ -81,6 +84,7 @@ class Bloom:
         Returns:
             A Bloom instance.
         """
+        self.model_path = model_path
         self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
         invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
         self.n_ctx = n_ctx
@@ -91,15 +95,39 @@ class Bloom:
         self.last_n_tokens_size = last_n_tokens_size
         self.verbose = verbose
 
-    def __call__(self, prompt: str, max_tokens: int, stream: bool = False):
+    def __call__(self, prompt: str, max_tokens: int, stream: bool = False,
+                 stop: Optional[List[str]] = []):
         if stream:
-            return self.stream(prompt, max_tokens)
+            return self.stream(prompt, max_tokens, stop)
         else:
-            return self._eval(prompt, max_tokens, False)
+            return self._eval(prompt, max_tokens, False, stop)
 
-    def _eval(self, prompt: str, max_tokens: int, match_str: bool):
+    def _eval(self, prompt: str, max_tokens: int, match_str: bool,
+              stop: Optional[List[str]] = []):
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
         if prompt.endswith("</s>") or max_tokens < 1:
-            return prompt
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": "length",
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None,
+                    "completion_tokens": None,
+                    "total_tokens": None,
+                }
+            }
         # use `buf` to store prompt and generated string,
         # assume the average length of words is less than 20 bytes
         buf = bytes((len(prompt) + max_tokens) * 20)
@@ -112,18 +140,85 @@ class Bloom:
                   prompt=bytes(prompt, encoding='utf-8'),
                   buf=buf)
         s = str(buf, encoding='utf-8').rstrip("\x00")
-        return s
 
-    def stream(self, prompt: str, max_tokens: int):
+        text = s.split(prompt)[1]
+        split_text = text
+        if stop != []:
+            for stop_word in stop:
+                split_text = split_text.split(stop_word)[0]
+        if split_text != text:
+            finish_reason = "stop"
+        else:
+            finish_reason = None
+        return {"id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt + split_text,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None,
+                    "completion_tokens": None,
+                    "total_tokens": None,
+                }
+                }
+
+    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = []):
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
         if prompt.endswith("</s>") or max_tokens < 1:
-            yield prompt
+            yield {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": "length",
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None
+                }
+            }
         else:
             for i in range(max_tokens):
                 if prompt.endswith("</s>"):
                     break
                 else:
-                    prompt = self._eval(prompt, 1, i != 0)
-                    yield prompt
+                    prompt = self._eval(prompt, 1, i != 0, stop)['choices'][0]['text']
+                    yield {
+                        "id": completion_id,
+                        "object": "text_completion",
+                        "created": created,
+                        "model": self.model_path,
+                        "choices": [
+                            {
+                                "text": prompt,
+                                "index": 0,
+                                "logprobs": None,
+                                "finish_reason": None,
+                            }
+                        ],
+                        "usage":
+                        {
+                            # TODO: need tokenize
+                            "prompt_tokens": None
+                        }
+                    }
 
     def free(self):
         bloom_free(self.ctx)
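
For reference, a minimal usage sketch of the patched API. The model path and prompts below are hypothetical, the import path is inferred from the file location, and the constructor is assumed to provide defaults for its remaining parameters (`n_ctx`, `n_threads`, etc., which the diff does not show); the completion-dict shape follows the diff above:

```python
from bigdl.llm.ggml.model.bloom.bloom import Bloom

# Hypothetical path to a ggml-quantized BLOOM checkpoint.
llm = Bloom("./bigdl_llm_bloom_q4_0.bin")

# Non-streaming call: returns a single OpenAI-style completion dict.
# "text" holds prompt + generated text; "finish_reason" is "stop" when
# a stop word truncated the output, "length" on the early-return path,
# and None otherwise.
result = llm("Q: What is CPU?\nA:", max_tokens=32, stop=["\nQ:"])
print(result["choices"][0]["text"])
print(result["choices"][0]["finish_reason"])

# Streaming call: yields one completion dict per generated token; each
# chunk's "text" carries the full text produced so far (prompt included),
# since _eval's output is fed back in as the next prompt.
for chunk in llm("Once upon a time", max_tokens=16, stream=True):
    print(chunk["choices"][0]["text"])

llm.free()
```

Note that the `usage` token counts are `None` for now, pending tokenization support (the `# TODO: need tokenize` markers in the diff).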