LLM: fix bugs in Bloom support for LangChain (#8362)

Author: binbin Deng
Date:   2023-06-20 13:30:37 +08:00 (committed by GitHub)
Parent: 066f53232d
Commit: 5f4f399ca7


@@ -46,7 +46,7 @@
 # only search the first bigdl package and end up finding only one sub-package.
 from .bloom_cpp import bloom_load, bloom_free, bloom_run
-from .bloom_cpp import bloom_tokenize, bloom_detokenize, bloom_forward, bloom_eval
+from .bloom_cpp import bloom_tokenize, bloom_detokenize, bloom_forward, bloom_eval, bloom_embed
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.model.generation import GenerationMixin
 from typing import List, Optional, Generator, Sequence, Union
@@ -127,7 +127,7 @@ class Bloom(GenerationMixin):
         # TODO: Some parameters are temporarily not supported
         unsupported_arg = {'n_parts': -1, 'n_gpu_layers': 0, 'f16_kv': True, 'logits_all': False,
                            'vocab_only': False, 'use_mmap': True, 'use_mlock': False,
-                           'embedding': False, 'last_n_tokens_size': 64, 'lora_base': None,
+                           'last_n_tokens_size': 64, 'lora_base': None,
                            'lora_path': None, 'verbose': True}
         for arg in unsupported_arg.keys():
             invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
@@ -156,44 +156,50 @@ class Bloom(GenerationMixin):
     ):
         # TODO: Some parameters are temporarily not supported
         # Unsupported parameters are checked in `_supported_call`
-        return self._supported_call(prompt, max_tokens, stream, stop,
-                                    suffix, temperature, top_p, logprobs, echo, frequency_penalty,
+        return self._supported_call(prompt, max_tokens, stream, stop, echo, model,
+                                    suffix, temperature, top_p, logprobs, frequency_penalty,
                                     presence_penalty, repeat_penalty, top_k, tfs_z, mirostat_mode,
-                                    mirostat_tau, mirostat_eta, model)
+                                    mirostat_tau, mirostat_eta)
 
     def _supported_call(self, prompt: str, max_tokens: int, stream: bool = False,
-                        stop: Optional[List[str]] = [], *args):
+                        stop: Optional[List[str]] = [], echo: bool = False,
+                        model: Optional[str] = None, *args):
         # Check unsupported parameters
-        unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs', 'echo',
+        unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs',
                            'frequency_penalty', 'presence_penalty', 'repeat_penalty', 'top_k',
                            'tfs_z', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'model']
         defult_value = {'suffix': None, 'temperature': 0.8, 'top_p': 0.95, 'logprobs': None,
-                        'echo': False, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
+                        'frequency_penalty': 0.0, 'presence_penalty': 0.0,
                         'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
-                        'mirostat_tau': 5.0, 'mirostat_eta': 0.1, 'model': None}
+                        'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
         for index in range(len(args)):
             invalidInputError(args[index] == defult_value[unsupported_arg[index]],
                               f"The parameter {unsupported_arg[index]} is temporarily "
                               "unsupported, please use the default value.")
         if stream:
-            return self.stream(prompt, max_tokens, stop)
+            return self.stream(prompt, max_tokens, stop, echo, model)
         else:
-            return self._eval(prompt, max_tokens, False, stop)
+            return self._eval(prompt, max_tokens, False, stop, echo, model)
 
     def _eval(self, prompt: str, max_tokens: int, match_str: bool,
-              stop: Optional[List[str]] = []):
+              stop: Optional[List[str]] = [], echo: bool = False, model: Optional[str] = None):
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        if model is None:
+            model_name = self.model_path
+        else:
+            model_name = model
+        prompt_len = len(self.tokenize(prompt))
         if prompt.endswith("</s>") or max_tokens < 1:
             return {
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt,
+                        "text": prompt if echo else "",
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": "length",
@@ -201,15 +207,14 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None,
-                    "completion_tokens": None,
-                    "total_tokens": None,
+                    "prompt_tokens": prompt_len,
+                    "completion_tokens": 0,
+                    "total_tokens": prompt_len,
                 }
             }
         # use `buf` to store prompt and generated string,
         # assume the average length of words is less than 20 bytes
-        buf = bytes((len(prompt) + max_tokens) * 20)
+        buf = bytes((prompt_len + max_tokens) * 20)
         ret = bloom_run(ctx=self.ctx,
                         seed=self.seed,
                         n_threads=self.n_threads,
@@ -229,13 +234,14 @@ class Bloom(GenerationMixin):
             finish_reason = "stop"
         else:
             finish_reason = None
+        completion_len = len(self.tokenize(split_text))
         return {"id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt + split_text,
+                        "text": prompt + split_text if echo else split_text,
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": finish_reason,
@@ -243,25 +249,31 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None,
-                    "completion_tokens": None,
-                    "total_tokens": None,
+                    "prompt_tokens": prompt_len,
+                    "completion_tokens": completion_len,
+                    "total_tokens": prompt_len + completion_len,
                 }
                 }
 
-    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = []):
+    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = [],
+               echo: bool = False, model: Optional[str] = None):
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        if model is None:
+            model_name = self.model_path
+        else:
+            model_name = model
+        prompt_tokens: List[int] = self.tokenize(prompt.encode("utf-8"))
+        prompt_len = len(prompt_tokens)
         if prompt.endswith("</s>") or max_tokens < 1:
             yield {
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt,
+                        "text": prompt if echo else "",
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": "length",
@@ -269,24 +281,25 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None
+                    "prompt_tokens": prompt_len
                 }
             }
         else:
             for i in range(max_tokens):
-                if prompt.endswith("</s>"):
+                token = self.forward(prompt_tokens)
+                prompt_tokens.append(token)
+                text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                if text.endswith("</s>"):
                     print('\n')
                     break
-                else:
-                    prompt = self._eval(prompt, 1, i != 0, stop)['choices'][0]['text']
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
                     "created": created,
-                    "model": self.model_path,
+                    "model": model_name,
                     "choices": [
                         {
-                            "text": prompt,
+                            "text": text,
                             "index": 0,
                             "logprobs": None,
                             "finish_reason": None,
@@ -294,8 +307,7 @@ class Bloom(GenerationMixin):
                     ],
                     "usage":
                     {
-                        # TODO: need tokenize
-                        "prompt_tokens": None
+                        "prompt_tokens": prompt_len
                     }
                 }
@@ -403,9 +415,11 @@ class Bloom(GenerationMixin):
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)
 
-    def embed(self, prompt: Union[str, bytes]) -> List[float]:
+    def embed(self, input: str) -> List[float]:
         """Only used for langchain"""
-        input_ids = self.tokenize(prompt)
+        invalidInputError(self.embedding,
+                          "Bloom model must be created with embedding=True to call this method.")
+        input_ids = self.tokenize(input)
         return bloom_embed(ctx=self.ctx,
                            input_ids=input_ids,
                            seed=self.seed,
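
A minimal usage sketch of the patched interface (not part of the commit). The import path, the model file name, and the constructor keyword arguments model_path and embedding are assumptions; only the stream() and embed() signatures, the embedding=True requirement, and the response fields come from the diff above.

# Hypothetical example, assuming the class is importable from this module path.
from bigdl.llm.ggml.model.bloom import Bloom

# 'embedding' is no longer listed as an unsupported parameter, so passing
# embedding=True is what enables the LangChain embedding integration.
llm = Bloom(model_path="./bigdl-llm-bloom-q4_0.bin", embedding=True)

# Streaming completion: each yielded chunk now carries the real prompt token
# count in "usage" and honors the new echo/model arguments.
for chunk in llm.stream("What is BigDL?", max_tokens=32, echo=False):
    print(chunk["choices"][0]["text"], end="", flush=True)

# Embedding call used by LangChain; the guard added in this commit raises
# invalidInputError unless the model was created with embedding=True.
vector = llm.embed("What is BigDL?")
print(len(vector))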