[LLM] Support chatglm cache (#8745)
parent faaccb64a2
commit 77844125f2
2 changed files with 30 additions and 49 deletions
@@ -46,8 +46,8 @@
 # only search the first bigdl package and end up finding only one sub-package.

-from .chatglm_cpp import chatglm_load, chatglm_tokenize, chatglm_detokenize, chatglm_eval, \
-    chatglm_eos_token
+from .chatglm_cpp import chatglm_load, chatglm_tokenize, chatglm_detokenize, \
+    chatglm_forward, chatglm_eos_token
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.model.generation import GenerationMixin
 from typing import List, Optional, Generator, Sequence, Union

@@ -219,21 +219,16 @@ class ChatGLM(GenerationMixin):
                 }
             }

-        n_past = 0
-        output_tokens = input_tokens
         for i in range(max_tokens):
             token = self.forward(input_ids=input_tokens,
-                                 n_past=n_past,
                                  top_k=top_k,
                                  top_p=top_p,
                                  temperature=temperature)
-            output_tokens.append(token)
-            n_past += len(input_tokens)
-            input_tokens = [token]
+            input_tokens.append(token)
             if token == self.eos_token():
                 break

-        text = self.detokenize(output_tokens)
+        text = self.detokenize(input_tokens)
         split_text = text[len(prompt):]
         split_text.rstrip('�')  # remove partial emoji
         if stop != []:
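For context on the hunk above: the Python side no longer tracks `n_past` or shrinks `input_tokens` down to the latest token; it appends every sampled token and passes the whole sequence back to `forward`, letting the native chatglm pipeline reuse its internal cache. A minimal sketch of the resulting decode loop, assuming an already-constructed `ChatGLM` instance `llm` and a public `tokenize` helper (both placeholders for illustration; only `forward`, `eos_token`, and `detokenize` appear in this diff):

# Sketch only: `llm` stands for a loaded ChatGLM instance; `tokenize` is assumed.
prompt = "What is AI?"
max_tokens = 128
input_tokens = llm.tokenize(prompt)
for _ in range(max_tokens):
    # Pass the full sequence each step; the backend caches prior context itself.
    token = llm.forward(input_ids=input_tokens,
                        top_k=0,
                        top_p=0.7,
                        temperature=0.95)
    input_tokens.append(token)
    if token == llm.eos_token():
        break
text = llm.detokenize(input_tokens)
completion = text[len(prompt):]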
@@ -243,7 +238,7 @@ class ChatGLM(GenerationMixin):
             finish_reason = "stop"
         else:
             finish_reason = None
-        completion_len = n_past - prompt_len
+        completion_len = len(input_tokens) - prompt_len
         return {
             "id": completion_id,
             "object": "text_completion",
@@ -288,28 +283,22 @@ class ChatGLM(GenerationMixin):
                         "finish_reason": "length",
                     }
                 ],
-                "usage":
-                {
-                    "prompt_tokens": prompt_len
+                "usage": {
+                    "prompt_tokens": prompt_len
                 }
             }
         else:
-            n_past = 0
-            output_tokens = input_tokens
             history_text = prompt
             for i in range(max_tokens):
                 token = self.forward(input_ids=input_tokens,
-                                     n_past=n_past,
                                      top_k=top_k,
                                      top_p=top_p,
                                      temperature=temperature)
-                output_tokens.append(token)
-                n_past += len(input_tokens)
-                input_tokens = [token]
+                input_tokens.append(token)
                 if token == self.eos_token():
                     print('\n')
                     break
-                text = self.detokenize(output_tokens)
+                text = self.detokenize(input_tokens)
                 if text.endswith('�'):
                     # generated new token is part of an emoji
                     # (some emoji consists of multiple tokens)
@@ -330,17 +319,16 @@ class ChatGLM(GenerationMixin):
                         "finish_reason": None,
                     }
                 ],
-                "usage":
-                {
-                    "prompt_tokens": prompt_len
+                "usage": {
+                    "prompt_tokens": prompt_len
                 }
             }

-    def _tokenize(self, text: bytes, *args) -> List[int]:
+    def _tokenize(self, text: str, *args) -> List[int]:
         """Tokenize a string.

         Args:
-            text: The utf-8 encoded string to tokenize.
+            text: The string to tokenize.

         Raises:
             RuntimeError: If the tokenization failed.
@@ -366,18 +354,16 @@ class ChatGLM(GenerationMixin):

     def forward(self,
                 input_ids: List[int],
-                n_past: int,
                 do_sample: bool = True,
                 top_k: int = 0,
                 top_p: float = 0.7,
                 temperature: float = 0.95,) -> int:
-        return chatglm_eval(ctx=self.ctx,
-                            input_ids=input_ids,
-                            n_past=n_past,
-                            do_sample=do_sample,
-                            top_k=top_k,
-                            top_p=top_p,
-                            temperature=temperature)
+        return chatglm_forward(ctx=self.ctx,
+                               input_ids=input_ids,
+                               do_sample=do_sample,
+                               top_k=top_k,
+                               top_p=top_p,
+                               temperature=temperature)

     def eos_token(self) -> int:
         return chatglm_eos_token(self.ctx)
@@ -431,15 +417,12 @@ class ChatGLM(GenerationMixin):
                               "unsupported, please use the default value.")

         invalidInputError(self.ctx is not None, "The attribute `ctx` of `ChatGLM` object is None.")
-        n_past = 0
         while True:
             token = self.forward(input_ids=tokens,
-                                 n_past=n_past,
                                  top_k=top_k,
                                  top_p=top_p,
                                  temperature=temp)
-            n_past += len(tokens)
             tokens_or_none = yield token
-            tokens = [token]
+            tokens.append(token)
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)
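The streaming generator above follows the same pattern: it stops resetting `tokens` to `[token]` and advancing `n_past`, and instead keeps appending, including any tokens the caller pushes back in through the generator's `send()`. The standalone toy below illustrates just that `tokens_or_none = yield token` protocol; the counter "model" in it is invented purely for illustration and is not part of bigdl.

from typing import Generator, List, Optional, Sequence

def toy_generate(tokens: List[int]) -> Generator[int, Optional[Sequence[int]], None]:
    # Mirrors the diff: yield the next token, append it to the running context,
    # then fold in any tokens the caller sent back before the next step.
    while True:
        token = tokens[-1] + 1          # stand-in for self.forward(...)
        tokens_or_none = yield token
        tokens.append(token)
        if tokens_or_none is not None:
            tokens.extend(tokens_or_none)

gen = toy_generate([1, 2, 3])
print(next(gen))        # -> 4
print(gen.send([10]))   # caller injects extra context; next value is 11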
@@ -40,7 +40,7 @@ def chatglm_load(path: str,
     path = str(Path(path))
     pipeline = Pipeline(path, use_mmap)
     config = GenerationConfig(
-        max_context_length=n_ctx,
+        max_length=n_ctx,
         num_threads=n_threads,
     )
     return ChatGLMContext(pipeline, config)
@@ -54,20 +54,18 @@ def chatglm_detokenize(ctx: ChatGLMContext, input_ids: List[int]) -> str:
     return ctx.pipeline.tokenizer.decode(input_ids)


-def chatglm_eval(ctx: ChatGLMContext,
-                 input_ids: List[int],
-                 n_past: int,
-                 do_sample: bool = True,
-                 top_k: int = 0,
-                 top_p: float = 0.7,
-                 temperature: float = 0.95,
-                 ) -> int:
+def chatglm_forward(ctx: ChatGLMContext,
+                    input_ids: List[int],
+                    do_sample: bool = True,
+                    top_k: int = 0,
+                    top_p: float = 0.7,
+                    temperature: float = 0.95,
+                    ) -> int:
     ctx.config.do_sample = do_sample
     ctx.config.top_k = top_k
     ctx.config.top_p = top_p
-    ctx.temperature = temperature
-    return ctx.pipeline.model.generate_next_token(input_ids, ctx.config, n_past,
-                                                  ctx.config.max_context_length)
+    ctx.config.temperature = temperature
+    return ctx.pipeline.forward(input_ids, ctx.config)


 def chatglm_eos_token(ctx: ChatGLMContext):
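As a reference for driving the lower-level wrapper directly after this change: load a context, tokenize the prompt, and call `chatglm_forward` with the full token list on every step, with no explicit `n_past`. This is a rough sketch; the model path is a placeholder, the module path is inferred from the relative import in the first file, and the keyword names follow the signatures shown in this diff.

# Sketch only: the model path is a placeholder and the module path is an assumption.
from bigdl.llm.ggml.model.chatglm.chatglm_cpp import (
    chatglm_load, chatglm_tokenize, chatglm_detokenize,
    chatglm_forward, chatglm_eos_token)

ctx = chatglm_load("path/to/chatglm2-ggml.bin", use_mmap=True, n_ctx=2048, n_threads=8)
tokens = chatglm_tokenize(ctx, "What is AI?")
for _ in range(64):
    # The pipeline keeps its own cache, so the whole sequence is passed each call.
    token = chatglm_forward(ctx, tokens, do_sample=True,
                            top_k=0, top_p=0.7, temperature=0.95)
    tokens.append(token)
    if token == chatglm_eos_token(ctx):
        break
print(chatglm_detokenize(ctx, tokens))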