LLM: add stop words and enhance output for bloom pybinding (#8280)

binbin Deng 2023-06-08 14:06:06 +08:00 committed by GitHub
parent 6990328e5c
commit f9e2bda04a

@@ -47,6 +47,9 @@
 from .bloom_cpp import bloom_load, bloom_free, bloom_run
 from bigdl.llm.utils.common import invalidInputError
+from typing import List, Optional
+import time
+import uuid
 
 
 class Bloom:
@@ -81,6 +84,7 @@ class Bloom:
         Returns:
             A Bloom instance.
         """
+        self.model_path = model_path
         self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
         invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
         self.n_ctx = n_ctx
@@ -91,15 +95,39 @@ class Bloom:
         self.last_n_tokens_size = last_n_tokens_size
         self.verbose = verbose
 
-    def __call__(self, prompt: str, max_tokens: int, stream: bool = False):
+    def __call__(self, prompt: str, max_tokens: int, stream: bool = False,
+                 stop: Optional[List[str]] = []):
         if stream:
-            return self.stream(prompt, max_tokens)
+            return self.stream(prompt, max_tokens, stop)
         else:
-            return self._eval(prompt, max_tokens, False)
+            return self._eval(prompt, max_tokens, False, stop)
 
-    def _eval(self, prompt: str, max_tokens: int, match_str: bool):
+    def _eval(self, prompt: str, max_tokens: int, match_str: bool,
+              stop: Optional[List[str]] = []):
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
         if prompt.endswith("</s>") or max_tokens < 1:
-            return prompt
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": "length",
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None,
+                    "completion_tokens": None,
+                    "total_tokens": None,
+                }
+            }
         # use `buf` to store prompt and generated string,
         # assume the average length of words is less than 20 bytes
         buf = bytes((len(prompt) + max_tokens) * 20)
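
The `buf` comment above encodes the sizing contract with the C binding: the caller allocates a zero-filled buffer generous enough for the prompt plus the generated continuation, `bloom_run` writes the UTF-8 result into its head, and Python trims the trailing NUL padding. A minimal sketch of that contract, using a hypothetical `fake_run` in place of the C-side `bloom_run`:

    # Hypothetical stand-in for bloom_run: writes UTF-8 text into the
    # start of a caller-provided zero-filled buffer.
    def fake_run(prompt: str, buf: bytearray) -> None:
        out = (prompt + " world").encode("utf-8")
        buf[:len(out)] = out

    prompt = "hello"
    max_tokens = 8
    # ~20 bytes per unit of prompt/output, the same assumption as the comment above
    buf = bytearray((len(prompt) + max_tokens) * 20)
    fake_run(prompt, buf)
    s = bytes(buf).decode("utf-8").rstrip("\x00")  # trim unused NUL padding
    assert s == "hello world"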
@@ -112,18 +140,85 @@ class Bloom:
             prompt=bytes(prompt, encoding='utf-8'),
             buf=buf)
         s = str(buf, encoding='utf-8').rstrip("\x00")
-        return s
+        text = s.split(prompt)[1]
+        split_text = text
+        if stop != []:
+            for stop_word in stop:
+                split_text = split_text.split(stop_word)[0]
+        if split_text != text:
+            finish_reason = "stop"
+        else:
+            finish_reason = None
+        return {"id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt + split_text,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None,
+                    "completion_tokens": None,
+                    "total_tokens": None,
+                }
+                }
 
-    def stream(self, prompt: str, max_tokens: int):
+    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = []):
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
         if prompt.endswith("</s>") or max_tokens < 1:
-            yield prompt
+            yield {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": "length",
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None
+                }
+            }
         else:
             for i in range(max_tokens):
                 if prompt.endswith("</s>"):
                     break
                 else:
-                    prompt = self._eval(prompt, 1, i != 0)
-                    yield prompt
+                    prompt = self._eval(prompt, 1, i != 0, stop)['choices'][0]['text']
+                    yield {
+                        "id": completion_id,
+                        "object": "text_completion",
+                        "created": created,
+                        "model": self.model_path,
+                        "choices": [
+                            {
+                                "text": prompt,
+                                "index": 0,
+                                "logprobs": None,
+                                "finish_reason": None,
+                            }
+                        ],
+                        "usage":
+                        {
+                            # TODO: need tokenize
+                            "prompt_tokens": None
+                        }
+                    }
 
     def free(self):
         bloom_free(self.ctx)
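
For context, a hedged usage sketch of the updated binding (the import path, model file name, and constructor defaults are assumptions, not taken from this diff): non-streaming calls now return a single OpenAI-style completion dict, truncated at the first stop word with finish_reason set to "stop", while streaming calls yield one such dict per generated step whose text field holds the cumulative prompt plus generation so far.

    from bigdl.llm.models import Bloom  # import path assumed; adjust to your install

    llm = Bloom("/path/to/bloomz-7b1-q4_0.bin")  # hypothetical model file

    # Non-streaming: one completion dict; generation is cut at the first "\n",
    # and finish_reason is then "stop" instead of None.
    result = llm("Q: What is the capital of France?\nA:", max_tokens=32, stop=["\n"])
    print(result["choices"][0]["text"], result["choices"][0]["finish_reason"])

    # Streaming: one dict per generated token; per the diff, "text" is the
    # cumulative prompt plus generation, so keep only the last chunk.
    final = ""
    for chunk in llm("Once upon a time", max_tokens=16, stream=True):
        final = chunk["choices"][0]["text"]
    print(final)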