LLM: add stop words and enhance output for bloom pybinding (#8280)
This commit is contained in:
parent
6990328e5c
commit
f9e2bda04a
1 changed files with 105 additions and 10 deletions
|
|
@ -47,6 +47,9 @@
|
|||
|
||||
from .bloom_cpp import bloom_load, bloom_free, bloom_run
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
from typing import List, Optional
|
||||
import time
|
||||
import uuid
|
||||
|
||||
|
||||
class Bloom:
|
||||
|
|
@ -81,6 +84,7 @@ class Bloom:
|
|||
Returns:
|
||||
A Bloom instance.
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
|
||||
invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
|
||||
self.n_ctx = n_ctx
|
||||
|
|
@ -91,15 +95,39 @@ class Bloom:
|
|||
self.last_n_tokens_size = last_n_tokens_size
|
||||
self.verbose = verbose
|
||||
|
||||
def __call__(self, prompt: str, max_tokens: int, stream: bool = False):
|
||||
def __call__(self, prompt: str, max_tokens: int, stream: bool = False,
|
||||
stop: Optional[List[str]] = []):
|
||||
if stream:
|
||||
return self.stream(prompt, max_tokens)
|
||||
return self.stream(prompt, max_tokens, stop)
|
||||
else:
|
||||
return self._eval(prompt, max_tokens, False)
|
||||
return self._eval(prompt, max_tokens, False, stop)
|
||||
|
||||
def _eval(self, prompt: str, max_tokens: int, match_str: bool):
|
||||
def _eval(self, prompt: str, max_tokens: int, match_str: bool,
|
||||
stop: Optional[List[str]] = []):
|
||||
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
||||
created: int = int(time.time())
|
||||
if prompt.endswith("</s>") or max_tokens < 1:
|
||||
return prompt
|
||||
return {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
"created": created,
|
||||
"model": self.model_path,
|
||||
"choices": [
|
||||
{
|
||||
"text": prompt,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "length",
|
||||
}
|
||||
],
|
||||
"usage":
|
||||
{
|
||||
# TODO: need tokenize
|
||||
"prompt_tokens": None,
|
||||
"completion_tokens": None,
|
||||
"total_tokens": None,
|
||||
}
|
||||
}
|
||||
# use `buf` to store prompt and generated string,
|
||||
# assume the average length of words is less than 20 bytes
|
||||
buf = bytes((len(prompt) + max_tokens) * 20)
|
||||
|
|
@ -112,18 +140,85 @@ class Bloom:
|
|||
prompt=bytes(prompt, encoding='utf-8'),
|
||||
buf=buf)
|
||||
s = str(buf, encoding='utf-8').rstrip("\x00")
|
||||
return s
|
||||
|
||||
def stream(self, prompt: str, max_tokens: int):
|
||||
text = s.split(prompt)[1]
|
||||
split_text = text
|
||||
if stop != []:
|
||||
for stop_word in stop:
|
||||
split_text = split_text.split(stop_word)[0]
|
||||
if split_text != text:
|
||||
finish_reason = "stop"
|
||||
else:
|
||||
finish_reason = None
|
||||
return {"id": completion_id,
|
||||
"object": "text_completion",
|
||||
"created": created,
|
||||
"model": self.model_path,
|
||||
"choices": [
|
||||
{
|
||||
"text": prompt + split_text,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage":
|
||||
{
|
||||
# TODO: need tokenize
|
||||
"prompt_tokens": None,
|
||||
"completion_tokens": None,
|
||||
"total_tokens": None,
|
||||
}
|
||||
}
|
||||
|
||||
def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = []):
|
||||
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
||||
created: int = int(time.time())
|
||||
if prompt.endswith("</s>") or max_tokens < 1:
|
||||
yield prompt
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
"created": created,
|
||||
"model": self.model_path,
|
||||
"choices": [
|
||||
{
|
||||
"text": prompt,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "length",
|
||||
}
|
||||
],
|
||||
"usage":
|
||||
{
|
||||
# TODO: need tokenize
|
||||
"prompt_tokens": None
|
||||
}
|
||||
}
|
||||
else:
|
||||
for i in range(max_tokens):
|
||||
if prompt.endswith("</s>"):
|
||||
break
|
||||
else:
|
||||
prompt = self._eval(prompt, 1, i != 0)
|
||||
yield prompt
|
||||
prompt = self._eval(prompt, 1, i != 0, stop)['choices'][0]['text']
|
||||
yield {
|
||||
"id": completion_id,
|
||||
"object": "text_completion",
|
||||
"created": created,
|
||||
"model": self.model_path,
|
||||
"choices": [
|
||||
{
|
||||
"text": prompt,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
"usage":
|
||||
{
|
||||
# TODO: need tokenize
|
||||
"prompt_tokens": None
|
||||
}
|
||||
}
|
||||
|
||||
def free(self):
|
||||
bloom_free(self.ctx)
|
||||
|
|
|
|||
Loading…
Reference in a new issue