[LLM] Support chatglm cache (#8745)

Yishuo Wang 2023-08-14 15:10:46 +08:00 committed by GitHub
parent faaccb64a2
commit 77844125f2
2 changed files with 30 additions and 49 deletions

@@ -46,8 +46,8 @@
 # only search the first bigdl package and end up finding only one sub-package.
-from .chatglm_cpp import chatglm_load, chatglm_tokenize, chatglm_detokenize, chatglm_eval, \
-    chatglm_eos_token
+from .chatglm_cpp import chatglm_load, chatglm_tokenize, chatglm_detokenize, \
+    chatglm_forward, chatglm_eos_token
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.model.generation import GenerationMixin
 from typing import List, Optional, Generator, Sequence, Union
@@ -219,21 +219,16 @@ class ChatGLM(GenerationMixin):
                 }
             }
-        n_past = 0
-        output_tokens = input_tokens
         for i in range(max_tokens):
             token = self.forward(input_ids=input_tokens,
-                                 n_past=n_past,
                                  top_k=top_k,
                                  top_p=top_p,
                                  temperature=temperature)
-            output_tokens.append(token)
-            n_past += len(input_tokens)
-            input_tokens = [token]
+            input_tokens.append(token)
             if token == self.eos_token():
                 break
-        text = self.detokenize(output_tokens)
+        text = self.detokenize(input_tokens)
         split_text = text[len(prompt):]
         split_text.rstrip('�') # remove partial emoji
         if stop != []:
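For context: with the position bookkeeping presumably handled inside the native chatglm.cpp context now, the Python loop no longer tracks `n_past` or keeps a separate `output_tokens` list; it simply appends each sampled token and feeds the growing `input_tokens` back into `forward`. A minimal sketch of that pattern (not part of the diff), with `model` standing in for a `ChatGLM` instance and only the `forward`, `eos_token` and `detokenize` names taken from the code above:

    def complete(model, input_tokens, max_tokens, top_k=0, top_p=0.7, temperature=0.95):
        # The native context remembers how many tokens it has already consumed,
        # so Python only appends the newly sampled token each step.
        for _ in range(max_tokens):
            token = model.forward(input_ids=input_tokens,
                                  top_k=top_k,
                                  top_p=top_p,
                                  temperature=temperature)
            input_tokens.append(token)
            if token == model.eos_token():
                break
        return model.detokenize(input_tokens)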
@@ -243,7 +238,7 @@ class ChatGLM(GenerationMixin):
             finish_reason = "stop"
         else:
             finish_reason = None
-        completion_len = n_past - prompt_len
+        completion_len = len(input_tokens) - prompt_len
         return {
             "id": completion_id,
             "object": "text_completion",
@@ -288,28 +283,22 @@ class ChatGLM(GenerationMixin):
                         "finish_reason": "length",
                     }
                 ],
-                "usage":
-                {
-                    "prompt_tokens": prompt_len
+                "usage": {
+                    "prompt_tokens": prompt_len
                 }
             }
         else:
-            n_past = 0
-            output_tokens = input_tokens
             history_text = prompt
             for i in range(max_tokens):
                 token = self.forward(input_ids=input_tokens,
-                                     n_past=n_past,
                                      top_k=top_k,
                                      top_p=top_p,
                                      temperature=temperature)
-                output_tokens.append(token)
-                n_past += len(input_tokens)
-                input_tokens = [token]
+                input_tokens.append(token)
                 if token == self.eos_token():
                     print('\n')
                     break
-                text = self.detokenize(output_tokens)
+                text = self.detokenize(input_tokens)
                 if text.endswith('�'):
                     # generated new token is part of an emoji
                     # (some emoji consists of multiple tokens)
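In the streaming branch, the running token list is detokenized on every step; when the newest token is only a fragment of a multi-token emoji, the decoded text ends with the replacement character '�', so the delta is held back until the emoji is complete. A rough sketch of that guard (the helper name and return convention are invented for illustration):

    def next_delta(detokenize, tokens, history_text):
        text = detokenize(tokens)
        if text.endswith('\ufffd'):
            # Newest token is a fragment of a multi-token emoji; emit nothing
            # and try again after the next token is generated.
            return history_text, None
        delta = text[len(history_text):]
        return text, delta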
@@ -330,17 +319,16 @@ class ChatGLM(GenerationMixin):
                             "finish_reason": None,
                         }
                     ],
-                    "usage":
-                    {
-                        "prompt_tokens": prompt_len
+                    "usage": {
+                        "prompt_tokens": prompt_len
                     }
                 }

-    def _tokenize(self, text: bytes, *args) -> List[int]:
+    def _tokenize(self, text: str, *args) -> List[int]:
         """Tokenize a string.

         Args:
-            text: The utf-8 encoded string to tokenize.
+            text: The string to tokenize.

         Raises:
             RuntimeError: If the tokenization failed.
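`_tokenize` now accepts a plain `str` rather than utf-8 encoded `bytes`, so callers pass text straight through. A tiny illustrative helper (the `llm` parameter stands in for a `ChatGLM` instance; nothing here is part of the diff):

    from typing import List

    def tokenize_prompt(llm, prompt: str) -> List[int]:
        # Previously the prompt had to be encoded first,
        # e.g. llm._tokenize(prompt.encode("utf-8")); now the str is passed as-is.
        return llm._tokenize(prompt)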
@@ -366,18 +354,16 @@ class ChatGLM(GenerationMixin):
     def forward(self,
                 input_ids: List[int],
-                n_past: int,
                 do_sample: bool = True,
                 top_k: int = 0,
                 top_p: float = 0.7,
                 temperature: float = 0.95,) -> int:
-        return chatglm_eval(ctx=self.ctx,
-                            input_ids=input_ids,
-                            n_past=n_past,
-                            do_sample=do_sample,
-                            top_k=top_k,
-                            top_p=top_p,
-                            temperature=temperature)
+        return chatglm_forward(ctx=self.ctx,
+                               input_ids=input_ids,
+                               do_sample=do_sample,
+                               top_k=top_k,
+                               top_p=top_p,
+                               temperature=temperature)

     def eos_token(self) -> int:
         return chatglm_eos_token(self.ctx)
@@ -431,15 +417,12 @@ class ChatGLM(GenerationMixin):
                           "unsupported, please use the default value.")
         invalidInputError(self.ctx is not None, "The attribute `ctx` of `ChatGLM` object is None.")
-        n_past = 0
         while True:
             token = self.forward(input_ids=tokens,
-                                 n_past=n_past,
                                  top_k=top_k,
                                  top_p=top_p,
                                  temperature=temp)
-            n_past += len(tokens)
             tokens_or_none = yield token
-            tokens = [token]
+            tokens.append(token)
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)
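The generator above mirrors the same change: the `n_past` bookkeeping is gone and each yielded token is appended to `tokens` before the next call. A hedged sketch of how such a generator might be consumed (the method name `_generate` and its keyword argument are assumptions made only for illustration):

    def stream_text(llm, prompt_tokens, max_new_tokens=128):
        generated = []
        for token in llm._generate(tokens=list(prompt_tokens)):
            generated.append(token)
            if token == llm.eos_token() or len(generated) >= max_new_tokens:
                break
        return llm.detokenize(list(prompt_tokens) + generated)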

@@ -40,7 +40,7 @@ def chatglm_load(path: str,
     path = str(Path(path))
     pipeline = Pipeline(path, use_mmap)
     config = GenerationConfig(
-        max_context_length=n_ctx,
+        max_length=n_ctx,
         num_threads=n_threads,
     )
     return ChatGLMContext(pipeline, config)
@@ -54,20 +54,18 @@ def chatglm_detokenize(ctx: ChatGLMContext, input_ids: List[int]) -> str:
     return ctx.pipeline.tokenizer.decode(input_ids)


-def chatglm_eval(ctx: ChatGLMContext,
-                 input_ids: List[int],
-                 n_past: int,
-                 do_sample: bool = True,
-                 top_k: int = 0,
-                 top_p: float = 0.7,
-                 temperature: float = 0.95,
-                 ) -> int:
+def chatglm_forward(ctx: ChatGLMContext,
+                    input_ids: List[int],
+                    do_sample: bool = True,
+                    top_k: int = 0,
+                    top_p: float = 0.7,
+                    temperature: float = 0.95,
+                    ) -> int:
     ctx.config.do_sample = do_sample
     ctx.config.top_k = top_k
     ctx.config.top_p = top_p
-    ctx.temperature = temperature
-    return ctx.pipeline.model.generate_next_token(input_ids, ctx.config, n_past,
-                                                  ctx.config.max_context_length)
+    ctx.config.temperature = temperature
+    return ctx.pipeline.forward(input_ids, ctx.config)


 def chatglm_eos_token(ctx: ChatGLMContext):
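Putting the second file's changes together, an end-to-end sketch of the module-level bindings (not part of the diff; the module path is inferred from the relative import in the first file, the model path and the `n_ctx`/`n_threads` values and 128-token cap are placeholders, and `chatglm_load` may take further parameters such as `use_mmap`):

    from bigdl.llm.ggml.model.chatglm.chatglm_cpp import (
        chatglm_load, chatglm_tokenize, chatglm_detokenize,
        chatglm_forward, chatglm_eos_token,
    )

    ctx = chatglm_load("/path/to/chatglm2-ggml.bin", n_ctx=2048, n_threads=8)
    tokens = chatglm_tokenize(ctx, "What is AI?")
    for _ in range(128):
        # chatglm_forward no longer takes n_past; the pipeline tracks the
        # cache position itself, so the full token list is passed each call.
        token = chatglm_forward(ctx=ctx, input_ids=tokens)
        tokens.append(token)
        if token == chatglm_eos_token(ctx):
            break
    print(chatglm_detokenize(ctx, tokens))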