LLM: add stop words and enhance output for bloom pybinding (#8280)
parent 6990328e5c
commit f9e2bda04a

1 changed file with 105 additions and 10 deletions
@@ -47,6 +47,9 @@
 from .bloom_cpp import bloom_load, bloom_free, bloom_run
 from bigdl.llm.utils.common import invalidInputError
+from typing import List, Optional
+import time
+import uuid
 
 
 class Bloom:
@@ -81,6 +84,7 @@ class Bloom:
         Returns:
             A Bloom instance.
         """
+        self.model_path = model_path
         self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
         invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
         self.n_ctx = n_ctx
@@ -91,15 +95,39 @@ class Bloom:
         self.last_n_tokens_size = last_n_tokens_size
         self.verbose = verbose
 
-    def __call__(self, prompt: str, max_tokens: int, stream: bool = False):
+    def __call__(self, prompt: str, max_tokens: int, stream: bool = False,
+                 stop: Optional[List[str]] = []):
         if stream:
-            return self.stream(prompt, max_tokens)
+            return self.stream(prompt, max_tokens, stop)
         else:
-            return self._eval(prompt, max_tokens, False)
+            return self._eval(prompt, max_tokens, False, stop)
 
-    def _eval(self, prompt: str, max_tokens: int, match_str: bool):
+    def _eval(self, prompt: str, max_tokens: int, match_str: bool,
+              stop: Optional[List[str]] = []):
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
         if prompt.endswith("</s>") or max_tokens < 1:
-            return prompt
+            return {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": "length",
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None,
+                    "completion_tokens": None,
+                    "total_tokens": None,
+                }
+            }
         # use `buf` to store prompt and generated string,
         # assume the average length of words is less than 20 bytes
         buf = bytes((len(prompt) + max_tokens) * 20)
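The `stop` parameter threaded through `__call__` and `_eval` above is applied to the generated text in the next hunk. A minimal standalone sketch of that truncation logic (the helper name `truncate_at_stop` is illustrative, not part of the commit):

from typing import List, Optional, Tuple

def truncate_at_stop(text: str, stop: Optional[List[str]] = None) -> Tuple[str, Optional[str]]:
    # Mirror of the logic added in _eval: each stop word splits the text
    # and only the leading part is kept; finish_reason becomes "stop"
    # only if a stop word actually removed something.
    split_text = text
    for stop_word in (stop or []):
        split_text = split_text.split(stop_word)[0]
    finish_reason = "stop" if split_text != text else None
    return split_text, finish_reason

# truncate_at_stop("Hello world\nGoodbye", ["\n"]) -> ("Hello world", "stop")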
@@ -112,18 +140,85 @@ class Bloom:
             prompt=bytes(prompt, encoding='utf-8'),
             buf=buf)
         s = str(buf, encoding='utf-8').rstrip("\x00")
-        return s
-
-    def stream(self, prompt: str, max_tokens: int):
+        text = s.split(prompt)[1]
+        split_text = text
+        if stop != []:
+            for stop_word in stop:
+                split_text = split_text.split(stop_word)[0]
+        if split_text != text:
+            finish_reason = "stop"
+        else:
+            finish_reason = None
+        return {"id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt + split_text,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None,
+                    "completion_tokens": None,
+                    "total_tokens": None,
+                }
+                }
+
+    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = []):
+        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
+        created: int = int(time.time())
         if prompt.endswith("</s>") or max_tokens < 1:
-            yield prompt
+            yield {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": prompt,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": "length",
+                    }
+                ],
+                "usage":
+                {
+                    # TODO: need tokenize
+                    "prompt_tokens": None
+                }
+            }
         else:
             for i in range(max_tokens):
                 if prompt.endswith("</s>"):
                     break
                 else:
-                    prompt = self._eval(prompt, 1, i != 0)
-                    yield prompt
+                    prompt = self._eval(prompt, 1, i != 0, stop)['choices'][0]['text']
+                    yield {
+                        "id": completion_id,
+                        "object": "text_completion",
+                        "created": created,
+                        "model": self.model_path,
+                        "choices": [
+                            {
+                                "text": prompt,
+                                "index": 0,
+                                "logprobs": None,
+                                "finish_reason": None,
+                            }
+                        ],
+                        "usage":
+                        {
+                            # TODO: need tokenize
+                            "prompt_tokens": None
+                        }
+                    }
 
     def free(self):
         bloom_free(self.ctx)
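For context, a usage sketch of the pybinding after this change. The model path and the keyword spellings of the constructor arguments are assumptions; only the `__call__`/`stream` behavior comes from the diff above:

# Assumes `Bloom` is imported from its module in bigdl.llm and that a
# bloom ggml checkpoint exists at the given (illustrative) path.
llm = Bloom("./bloom-7b1-ggml.bin", n_ctx=512, n_threads=4)

# Non-streaming call: returns one OpenAI-style completion dict; output is
# cut at the first stop word, in which case finish_reason is "stop".
result = llm("Q: Name the planets.\nA:", max_tokens=32, stop=["\nQ:"])
print(result["choices"][0]["text"], result["choices"][0]["finish_reason"])

# Streaming call: returns a generator of completion dicts. Note that per
# this commit each chunk's "text" carries the accumulated prompt so far,
# not just the newly generated delta.
for chunk in llm("Once upon a time", max_tokens=16, stream=True):
    print(chunk["choices"][0]["text"])

llm.free()  # release the native context loaded by bloom_load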