LLM: fix bugs in bloom support for langchain (#8362)

binbin Deng 2023-06-20 13:30:37 +08:00 committed by GitHub
parent 066f53232d
commit 5f4f399ca7


@@ -46,7 +46,7 @@
 # only search the first bigdl package and end up finding only one sub-package.
 from .bloom_cpp import bloom_load, bloom_free, bloom_run
-from .bloom_cpp import bloom_tokenize, bloom_detokenize, bloom_forward, bloom_eval
+from .bloom_cpp import bloom_tokenize, bloom_detokenize, bloom_forward, bloom_eval, bloom_embed
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.model.generation import GenerationMixin
 from typing import List, Optional, Generator, Sequence, Union
@@ -127,7 +127,7 @@ class Bloom(GenerationMixin):
         # TODO: Some parameters are temporarily not supported
         unsupported_arg = {'n_parts': -1, 'n_gpu_layers': 0, 'f16_kv': True, 'logits_all': False,
                            'vocab_only': False, 'use_mmap': True, 'use_mlock': False,
-                           'embedding': False, 'last_n_tokens_size': 64, 'lora_base': None,
+                           'last_n_tokens_size': 64, 'lora_base': None,
                            'lora_path': None, 'verbose': True}
         for arg in unsupported_arg.keys():
             invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
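
Dropping 'embedding' from the fixed-value check above means embedding=True becomes a supported constructor argument, which the new embed() method later in this diff relies on. A minimal construction sketch, not taken from this commit; the import path and model file name are assumptions:

# Hypothetical construction sketch; the import path and model file are placeholders.
from bigdl.llm.ggml.model.bloom import Bloom

llm = Bloom(model_path="./bloom-7b1-ggml-q4_0.bin", embedding=True)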
@@ -156,44 +156,50 @@ class Bloom(GenerationMixin):
     ):
         # TODO: Some parameters are temporarily not supported
         # Unsupported parameters are checked in `_supported_call`
-        return self._supported_call(prompt, max_tokens, stream, stop,
-                                    suffix, temperature, top_p, logprobs, echo, frequency_penalty,
+        return self._supported_call(prompt, max_tokens, stream, stop, echo, model,
+                                    suffix, temperature, top_p, logprobs, frequency_penalty,
                                     presence_penalty, repeat_penalty, top_k, tfs_z, mirostat_mode,
-                                    mirostat_tau, mirostat_eta, model)
+                                    mirostat_tau, mirostat_eta)
 
     def _supported_call(self, prompt: str, max_tokens: int, stream: bool = False,
-                        stop: Optional[List[str]] = [], *args):
+                        stop: Optional[List[str]] = [], echo: bool = False,
+                        model: Optional[str] = None, *args):
         # Check unsupporeted parameters
-        unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs', 'echo',
+        unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs',
                            'frequency_penalty', 'presence_penalty', 'repeat_penalty', 'top_k',
                            'tfs_z', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'model']
         defult_value = {'suffix': None, 'temperature': 0.8, 'top_p': 0.95, 'logprobs': None,
-                        'echo': False, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
+                        'frequency_penalty': 0.0, 'presence_penalty': 0.0,
                         'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
-                        'mirostat_tau': 5.0, 'mirostat_eta': 0.1, 'model': None}
+                        'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
         for index in range(len(args)):
             invalidInputError(args[index] == defult_value[unsupported_arg[index]],
                               f"The parameter {unsupported_arg[index]} is temporarily "
                               "unsupported, please use the default value.")
         if stream:
-            return self.stream(prompt, max_tokens, stop)
+            return self.stream(prompt, max_tokens, stop, echo, model)
         else:
-            return self._eval(prompt, max_tokens, False, stop)
+            return self._eval(prompt, max_tokens, False, stop, echo, model)
 
     def _eval(self, prompt: str, max_tokens: int, match_str: bool,
-              stop: Optional[List[str]] = []):
+              stop: Optional[List[str]] = [], echo: bool = False, model: Optional[str] = None):
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        if model is None:
+            model_name = self.model_path
+        else:
+            model_name = model
+        prompt_len = len(self.tokenize(prompt))
         if prompt.endswith("</s>") or max_tokens < 1:
             return {
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt,
+                        "text": prompt if echo else "",
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": "length",
@@ -201,15 +207,14 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None,
-                    "completion_tokens": None,
-                    "total_tokens": None,
+                    "prompt_tokens": prompt_len,
+                    "completion_tokens": 0,
+                    "total_tokens": prompt_len,
                 }
             }
         # use `buf` to store prompt and generated string,
         # assume the average length of words is less than 20 bytes
-        buf = bytes((len(prompt) + max_tokens) * 20)
+        buf = bytes((prompt_len + max_tokens) * 20)
         ret = bloom_run(ctx=self.ctx,
                         seed=self.seed,
                         n_threads=self.n_threads,
@@ -229,13 +234,14 @@ class Bloom(GenerationMixin):
             finish_reason = "stop"
         else:
             finish_reason = None
+        completion_len = len(self.tokenize(split_text))
         return {"id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt + split_text,
+                        "text": prompt + split_text if echo else split_text,
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": finish_reason,
@@ -243,25 +249,31 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None,
-                    "completion_tokens": None,
-                    "total_tokens": None,
+                    "prompt_tokens": prompt_len,
+                    "completion_tokens": completion_len,
+                    "total_tokens": prompt_len + completion_len,
                 }
                 }
 
-    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = []):
+    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = [],
+               echo: bool = False, model: Optional[str] = None):
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        if model is None:
+            model_name = self.model_path
+        else:
+            model_name = model
+        prompt_tokens: List[int] = self.tokenize(prompt.encode("utf-8"))
+        prompt_len = len(prompt_tokens)
         if prompt.endswith("</s>") or max_tokens < 1:
             yield {
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt,
+                        "text": prompt if echo else "",
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": "length",
@@ -269,24 +281,25 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None
+                    "prompt_tokens": prompt_len
                 }
             }
         else:
             for i in range(max_tokens):
-                if prompt.endswith("</s>"):
+                token = self.forward(prompt_tokens)
+                prompt_tokens.append(token)
+                text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                if text.endswith("</s>"):
+                    print('\n')
                     break
-                else:
-                    prompt = self._eval(prompt, 1, i != 0, stop)['choices'][0]['text']
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
                     "created": created,
-                    "model": self.model_path,
+                    "model": model_name,
                     "choices": [
                         {
-                            "text": prompt,
+                            "text": text,
                             "index": 0,
                             "logprobs": None,
                             "finish_reason": None,
@@ -294,8 +307,7 @@ class Bloom(GenerationMixin):
                     ],
                     "usage":
                     {
-                        # TODO: need tokenize
-                        "prompt_tokens": None
+                        "prompt_tokens": prompt_len
                     }
                 }
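
With the changes above, stream() now generates one token at a time via forward()/detokenize() and yields each decoded piece in an OpenAI-style completion chunk, instead of re-running _eval on the growing prompt. A hedged usage sketch, reusing the llm instance constructed earlier; the keyword names follow the __call__ signature shown in this diff:

# Streaming sketch: each chunk carries one decoded token in choices[0]["text"].
for chunk in llm("Once upon a time", max_tokens=32, stream=True):
    print(chunk["choices"][0]["text"], end="", flush=True)
print()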
@@ -403,9 +415,11 @@ class Bloom(GenerationMixin):
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)
 
-    def embed(self, prompt: Union[str, bytes]) -> List[float]:
+    def embed(self, input: str) -> List[float]:
         """Only used for langchain"""
-        input_ids = self.tokenize(prompt)
+        invalidInputError(self.embedding,
+                          "Bloom model must be created with embedding=True to call this method.")
+        input_ids = self.tokenize(input)
         return bloom_embed(ctx=self.ctx,
                            input_ids=input_ids,
                            seed=self.seed,
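
The new guard makes embed() fail fast unless the model was loaded with embedding=True, which is what the langchain integration mentioned in the docstring calls into. A hedged sketch, reusing the llm instance from the earlier example:

# Returns a List[float]; raises invalidInputError if the model was not
# created with embedding=True.
vector = llm.embed("BigDL-LLM runs Bloom locally.")
print(len(vector))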