LLM: fix bugs in bloom support for langchain (#8362)
parent 066f53232d
commit 5f4f399ca7
1 changed file with 66 additions and 52 deletions
@@ -46,7 +46,7 @@
 # only search the first bigdl package and end up finding only one sub-package.

 from .bloom_cpp import bloom_load, bloom_free, bloom_run
-from .bloom_cpp import bloom_tokenize, bloom_detokenize, bloom_forward, bloom_eval
+from .bloom_cpp import bloom_tokenize, bloom_detokenize, bloom_forward, bloom_eval, bloom_embed
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.model.generation import GenerationMixin
 from typing import List, Optional, Generator, Sequence, Union
@@ -127,7 +127,7 @@ class Bloom(GenerationMixin):
         # TODO: Some parameters are temporarily not supported
         unsupported_arg = {'n_parts': -1, 'n_gpu_layers': 0, 'f16_kv': True, 'logits_all': False,
                            'vocab_only': False, 'use_mmap': True, 'use_mlock': False,
-                           'embedding': False, 'last_n_tokens_size': 64, 'lora_base': None,
+                           'last_n_tokens_size': 64, 'lora_base': None,
                            'lora_path': None, 'verbose': True}
         for arg in unsupported_arg.keys():
             invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
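Note: dropping 'embedding' from the unsupported table means the constructor no longer forces embedding=False, which the new embed() method (final hunk below) relies on. A minimal sketch of the now-permitted construction, with a placeholder model path and an assumed import location:

    from bigdl.llm.ggml.model.bloom import Bloom  # assumed import path

    # embedding=True was previously rejected by the unsupported-parameter
    # check; after this change it is accepted so embed() can be called later.
    llm = Bloom(model_path="./bloom-7b1-q4_0.bin",  # placeholder path
                embedding=True)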
@@ -156,44 +156,50 @@ class Bloom(GenerationMixin):
     ):
         # TODO: Some parameters are temporarily not supported
         # Unsupported parameters are checked in `_supported_call`
-        return self._supported_call(prompt, max_tokens, stream, stop,
-                                    suffix, temperature, top_p, logprobs, echo, frequency_penalty,
+        return self._supported_call(prompt, max_tokens, stream, stop, echo, model,
+                                    suffix, temperature, top_p, logprobs, frequency_penalty,
                                     presence_penalty, repeat_penalty, top_k, tfs_z, mirostat_mode,
-                                    mirostat_tau, mirostat_eta, model)
+                                    mirostat_tau, mirostat_eta)

     def _supported_call(self, prompt: str, max_tokens: int, stream: bool = False,
-                        stop: Optional[List[str]] = [], *args):
+                        stop: Optional[List[str]] = [], echo: bool = False,
+                        model: Optional[str] = None, *args):
         # Check unsupporeted parameters
-        unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs', 'echo',
+        unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs',
                            'frequency_penalty', 'presence_penalty', 'repeat_penalty', 'top_k',
                            'tfs_z', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'model']
         defult_value = {'suffix': None, 'temperature': 0.8, 'top_p': 0.95, 'logprobs': None,
-                        'echo': False, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
+                        'frequency_penalty': 0.0, 'presence_penalty': 0.0,
                         'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
-                        'mirostat_tau': 5.0, 'mirostat_eta': 0.1, 'model': None}
+                        'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
         for index in range(len(args)):
             invalidInputError(args[index] == defult_value[unsupported_arg[index]],
                               f"The parameter {unsupported_arg[index]} is temporarily "
                               "unsupported, please use the default value.")

         if stream:
-            return self.stream(prompt, max_tokens, stop)
+            return self.stream(prompt, max_tokens, stop, echo, model)
         else:
-            return self._eval(prompt, max_tokens, False, stop)
+            return self._eval(prompt, max_tokens, False, stop, echo, model)

     def _eval(self, prompt: str, max_tokens: int, match_str: bool,
-              stop: Optional[List[str]] = []):
+              stop: Optional[List[str]] = [], echo: bool = False, model: Optional[str] = None):
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        if model is None:
+            model_name = self.model_path
+        else:
+            model_name = model
+        prompt_len = len(self.tokenize(prompt))
         if prompt.endswith("</s>") or max_tokens < 1:
             return {
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt,
+                        "text": prompt if echo else "",
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": "length",
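Note: the reordering matters because _supported_call validates *args positionally. Once echo and model become named parameters, each remaining positional argument is compared against the default registered at the same index in unsupported_arg. The same pattern in isolation (names are illustrative, not from the source):

    def check_unsupported(args, names, defaults):
        # Each positional argument must still equal its default value;
        # anything else is reported as a temporarily unsupported parameter.
        for i, value in enumerate(args):
            if value != defaults[names[i]]:
                raise ValueError(f"The parameter {names[i]} is temporarily "
                                 "unsupported, please use the default value.")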
@@ -201,15 +207,14 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None,
-                    "completion_tokens": None,
-                    "total_tokens": None,
+                    "prompt_tokens": prompt_len,
+                    "completion_tokens": 0,
+                    "total_tokens": prompt_len,
                 }
             }
         # use `buf` to store prompt and generated string,
         # assume the average length of words is less than 20 bytes
-        buf = bytes((len(prompt) + max_tokens) * 20)
+        buf = bytes((prompt_len + max_tokens) * 20)
         ret = bloom_run(ctx=self.ctx,
                         seed=self.seed,
                         n_threads=self.n_threads,
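Note: with prompt_len computed up front via self.tokenize, the early-return branch can report real OpenAI-style usage numbers instead of None, and the scratch buffer is sized from the token count rather than the character count. A small sketch of the accounting, assuming llm is a loaded Bloom instance:

    prompt_len = len(llm.tokenize(prompt))   # token count, not character count
    usage = {"prompt_tokens": prompt_len,
             "completion_tokens": 0,         # nothing generated on early return
             "total_tokens": prompt_len}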
@@ -229,13 +234,14 @@ class Bloom(GenerationMixin):
             finish_reason = "stop"
         else:
             finish_reason = None
+        completion_len = len(self.tokenize(split_text))
         return {"id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt + split_text,
+                        "text": prompt + split_text if echo else split_text,
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": finish_reason,
@@ -243,25 +249,31 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None,
-                    "completion_tokens": None,
-                    "total_tokens": None,
+                    "prompt_tokens": prompt_len,
+                    "completion_tokens": completion_len,
+                    "total_tokens": prompt_len + completion_len,
                 }
                 }

-    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = []):
+    def stream(self, prompt: str, max_tokens: int, stop: Optional[List[str]] = [],
+               echo: bool = False, model: Optional[str] = None):
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        if model is None:
+            model_name = self.model_path
+        else:
+            model_name = model
+        prompt_tokens: List[int] = self.tokenize(prompt.encode("utf-8"))
+        prompt_len = len(prompt_tokens)
         if prompt.endswith("</s>") or max_tokens < 1:
             yield {
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": model_name,
                 "choices": [
                     {
-                        "text": prompt,
+                        "text": prompt if echo else "",
                         "index": 0,
                         "logprobs": None,
                         "finish_reason": "length",
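Note: stream() now accepts the same echo and model arguments as _eval, and each yielded chunk is an OpenAI-style completion dict. A hypothetical consumer, assuming llm was constructed as in the earlier sketch:

    for chunk in llm.stream("What is AI?", max_tokens=32):
        # each chunk carries one newly generated piece of text
        print(chunk["choices"][0]["text"], end="", flush=True)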
@@ -269,24 +281,25 @@ class Bloom(GenerationMixin):
                 ],
                 "usage":
                 {
-                    # TODO: need tokenize
-                    "prompt_tokens": None
+                    "prompt_tokens": prompt_len
                 }
             }
         else:
             for i in range(max_tokens):
-                if prompt.endswith("</s>"):
+                token = self.forward(prompt_tokens)
+                prompt_tokens.append(token)
+                text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                if text.endswith("</s>"):
+                    print('\n')
                     break
-                else:
-                    prompt = self._eval(prompt, 1, i != 0, stop)['choices'][0]['text']
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
                     "created": created,
-                    "model": self.model_path,
+                    "model": model_name,
                     "choices": [
                         {
-                            "text": prompt,
+                            "text": text,
                             "index": 0,
                             "logprobs": None,
                             "finish_reason": None,
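Note: the old loop re-ran _eval over the growing prompt on every step; the rewrite keeps the token history and decodes one token at a time with forward() and detokenize(). The same pattern in isolation (a sketch, using only the methods the diff itself calls):

    def stream_pieces(llm, prompt: str, max_tokens: int):
        tokens = llm.tokenize(prompt.encode("utf-8"))
        for _ in range(max_tokens):
            token = llm.forward(tokens)   # predict the next token from the full history
            tokens.append(token)          # grow the context for the next step
            piece = llm.detokenize([token]).decode("utf-8", errors="ignore")
            if piece.endswith("</s>"):    # end-of-sequence marker, stop generating
                break
            yield piece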
@@ -294,8 +307,7 @@ class Bloom(GenerationMixin):
                     ],
                     "usage":
                     {
-                        # TODO: need tokenize
-                        "prompt_tokens": None
+                        "prompt_tokens": prompt_len
                     }
                 }

@@ -403,9 +415,11 @@ class Bloom(GenerationMixin):
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)

-    def embed(self, prompt: Union[str, bytes]) -> List[float]:
+    def embed(self, input: str) -> List[float]:
         """Only used for langchain"""
-        input_ids = self.tokenize(prompt)
+        invalidInputError(self.embedding,
+                          "Bloom model must be created with embedding=True to call this method.")
+        input_ids = self.tokenize(input)
         return bloom_embed(ctx=self.ctx,
                            input_ids=input_ids,
                            seed=self.seed,
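Note: embed() now guards on self.embedding and takes a plain str, matching what langchain's embedding wrappers pass in. A hypothetical end-to-end call, reusing the llm constructed with embedding=True in the earlier sketch:

    vector = llm.embed("BigDL-LLM supports bloom in langchain")
    print(len(vector))   # dimensionality of the returned List[float]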