[LLM] Support chatglm cache (#8745)

Yishuo Wang authored on 2023-08-14 15:10:46 +08:00; committed by GitHub
parent faaccb64a2
commit 77844125f2
2 changed files with 30 additions and 49 deletions
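In short, the wrapper previously tracked the KV-cache position itself (an n_past counter advanced after every forward call) and fed only the newly sampled token back in; with the cache now kept inside chatglm-cpp, the Python side simply appends each generated token to the running token list and re-passes the whole list. A minimal sketch of the new decode loop, using only methods visible in this diff; llm, prompt_ids, and the sampling values are hypothetical stand-ins:

# Sketch of the post-change decode loop; no n_past bookkeeping is needed
# because the KV cache lives inside chatglm-cpp.
input_tokens = list(prompt_ids)      # hypothetical: token ids of the prompt
for _ in range(256):                 # arbitrary cap on new tokens for the sketch
    token = llm.forward(input_ids=input_tokens,
                        top_k=0, top_p=0.7, temperature=0.95)
    input_tokens.append(token)       # the full history is re-submitted each step
    if token == llm.eos_token():
        break
text = llm.detokenize(input_tokens)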

File 1 of 2: the ChatGLM Python wrapper (it imports from .chatglm_cpp)

@@ -46,8 +46,8 @@
 # only search the first bigdl package and end up finding only one sub-package.
-from .chatglm_cpp import chatglm_load, chatglm_tokenize, chatglm_detokenize, chatglm_eval, \
-    chatglm_eos_token
+from .chatglm_cpp import chatglm_load, chatglm_tokenize, chatglm_detokenize, \
+    chatglm_forward, chatglm_eos_token
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.model.generation import GenerationMixin
 from typing import List, Optional, Generator, Sequence, Union
@@ -219,21 +219,16 @@ class ChatGLM(GenerationMixin):
                 }
             }
-        n_past = 0
-        output_tokens = input_tokens
         for i in range(max_tokens):
             token = self.forward(input_ids=input_tokens,
-                                 n_past=n_past,
                                  top_k=top_k,
                                  top_p=top_p,
                                  temperature=temperature)
-            output_tokens.append(token)
-            n_past += len(input_tokens)
-            input_tokens = [token]
+            input_tokens.append(token)
             if token == self.eos_token():
                 break
-        text = self.detokenize(output_tokens)
+        text = self.detokenize(input_tokens)
         split_text = text[len(prompt):]
         split_text.rstrip('�')  # remove partial emoji
         if stop != []:
@@ -243,7 +238,7 @@ class ChatGLM(GenerationMixin):
             finish_reason = "stop"
         else:
             finish_reason = None
-        completion_len = n_past - prompt_len
+        completion_len = len(input_tokens) - prompt_len
         return {
             "id": completion_id,
             "object": "text_completion",
@@ -288,28 +283,22 @@ class ChatGLM(GenerationMixin):
                         "finish_reason": "length",
                     }
                 ],
-                "usage":
-                {
-                    "prompt_tokens": prompt_len
+                "usage": {
+                    "prompt_tokens": prompt_len
                 }
             }
         else:
-            n_past = 0
-            output_tokens = input_tokens
             history_text = prompt
             for i in range(max_tokens):
                 token = self.forward(input_ids=input_tokens,
-                                     n_past=n_past,
                                      top_k=top_k,
                                      top_p=top_p,
                                      temperature=temperature)
-                output_tokens.append(token)
-                n_past += len(input_tokens)
-                input_tokens = [token]
+                input_tokens.append(token)
                 if token == self.eos_token():
                     print('\n')
                     break
-            text = self.detokenize(output_tokens)
+            text = self.detokenize(input_tokens)
             if text.endswith('�'):
                 # generated new token is part of an emoji
                 # (some emoji consists of multiple tokens)
@@ -330,17 +319,16 @@ class ChatGLM(GenerationMixin):
                             "finish_reason": None,
                         }
                     ],
-                    "usage":
-                    {
-                        "prompt_tokens": prompt_len
+                    "usage": {
+                        "prompt_tokens": prompt_len
                     }
                 }

-    def _tokenize(self, text: bytes, *args) -> List[int]:
+    def _tokenize(self, text: str, *args) -> List[int]:
         """Tokenize a string.

         Args:
-            text: The utf-8 encoded string to tokenize.
+            text: The string to tokenize.

         Raises:
             RuntimeError: If the tokenization failed.
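One side effect of this hunk: _tokenize now takes a plain Python str rather than a UTF-8 encoded bytes object, so callers presumably no longer need to encode the prompt before tokenizing. A small sketch under that assumption (llm is a hypothetical loaded ChatGLM instance):

ids = llm._tokenize("What is AI?")   # before this change: llm._tokenize("What is AI?".encode("utf-8"))
text = llm.detokenize(ids)           # decode the ids back to text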
@@ -366,18 +354,16 @@ class ChatGLM(GenerationMixin):
     def forward(self,
                 input_ids: List[int],
-                n_past: int,
                 do_sample: bool = True,
                 top_k: int = 0,
                 top_p: float = 0.7,
                 temperature: float = 0.95,) -> int:
-        return chatglm_eval(ctx=self.ctx,
-                            input_ids=input_ids,
-                            n_past=n_past,
-                            do_sample=do_sample,
-                            top_k=top_k,
-                            top_p=top_p,
-                            temperature=temperature)
+        return chatglm_forward(ctx=self.ctx,
+                               input_ids=input_ids,
+                               do_sample=do_sample,
+                               top_k=top_k,
+                               top_p=top_p,
+                               temperature=temperature)

     def eos_token(self) -> int:
         return chatglm_eos_token(self.ctx)
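With n_past dropped from the signature, forward becomes a pure per-step sampling call: it hands the accumulated input_ids plus sampling parameters to chatglm_forward and lets the pipeline decide internally which positions still need evaluation, which appears to be the cache behaviour the commit title refers to. A sketch of two consecutive steps under that assumption (token ids are made up):

tokens = [101, 102, 103]                 # hypothetical prompt ids
first = llm.forward(input_ids=tokens)    # prompt is evaluated, first token sampled
tokens.append(first)
second = llm.forward(input_ids=tokens)   # presumably only the unseen suffix is evaluated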
@@ -431,15 +417,12 @@ class ChatGLM(GenerationMixin):
                               "unsupported, please use the default value.")
         invalidInputError(self.ctx is not None, "The attribute `ctx` of `ChatGLM` object is None.")
-        n_past = 0
         while True:
             token = self.forward(input_ids=tokens,
-                                 n_past=n_past,
                                  top_k=top_k,
                                  top_p=top_p,
                                  temperature=temp)
-            n_past += len(tokens)
             tokens_or_none = yield token
-            tokens = [token]
+            tokens.append(token)
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)
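The same simplification applies to the streaming token generator: instead of resetting tokens to the single new token and advancing n_past by hand, it keeps appending to tokens and re-submits the full list on every iteration. Since the while True loop never terminates on its own, the caller is expected to stop it; a sketch of driving it, where token_gen stands for the generator object this hunk modifies (the producing method is not shown in this hunk):

out = []
for token in token_gen:              # token_gen: hypothetical handle to this generator
    out.append(token)
    if token == llm.eos_token() or len(out) >= 32:
        break
print(llm.detokenize(out))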

File 2 of 2: the .chatglm_cpp helper module

@@ -40,7 +40,7 @@ def chatglm_load(path: str,
     path = str(Path(path))
     pipeline = Pipeline(path, use_mmap)
     config = GenerationConfig(
-        max_context_length=n_ctx,
+        max_length=n_ctx,
         num_threads=n_threads,
     )
     return ChatGLMContext(pipeline, config)
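On the binding side, the context length n_ctx is now mapped to GenerationConfig's max_length field instead of max_context_length; whether the two fields are interchangeable in the underlying chatglm-cpp config is not visible in this diff, so treat the mapping as this wrapper's choice. A minimal load sketch using only names from this hunk, assuming the remaining chatglm_load parameters keep their defaults (the path and sizes are placeholders):

ctx = chatglm_load(path="/models/chatglm2-ggml.bin",   # hypothetical model path
                   use_mmap=True, n_ctx=2048, n_threads=8)
# ctx.pipeline holds the chatglm-cpp Pipeline, ctx.config the GenerationConfig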
@@ -54,20 +54,18 @@ def chatglm_detokenize(ctx: ChatGLMContext, input_ids: List[int]) -> str:
     return ctx.pipeline.tokenizer.decode(input_ids)


-def chatglm_eval(ctx: ChatGLMContext,
-                 input_ids: List[int],
-                 n_past: int,
-                 do_sample: bool = True,
-                 top_k: int = 0,
-                 top_p: float = 0.7,
-                 temperature: float = 0.95,
-                 ) -> int:
+def chatglm_forward(ctx: ChatGLMContext,
+                    input_ids: List[int],
+                    do_sample: bool = True,
+                    top_k: int = 0,
+                    top_p: float = 0.7,
+                    temperature: float = 0.95,
+                    ) -> int:
     ctx.config.do_sample = do_sample
     ctx.config.top_k = top_k
     ctx.config.top_p = top_p
-    ctx.temperature = temperature
-    return ctx.pipeline.model.generate_next_token(input_ids, ctx.config, n_past,
-                                                  ctx.config.max_context_length)
+    ctx.config.temperature = temperature
+    return ctx.pipeline.forward(input_ids, ctx.config)


 def chatglm_eos_token(ctx: ChatGLMContext):
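Putting the module-level API together after this commit, a hedged end-to-end sketch using only the helpers named in the diff; chatglm_tokenize is assumed to mirror chatglm_detokenize and return a list of ids, and the path, prompt, and limits are placeholders:

ctx = chatglm_load(path="/models/chatglm2-ggml.bin",   # hypothetical path and settings
                   use_mmap=True, n_ctx=2048, n_threads=8)
ids = chatglm_tokenize(ctx, "What is BigDL-LLM?")
for _ in range(128):                                   # cap on new tokens
    token = chatglm_forward(ctx=ctx, input_ids=ids)    # KV cache handled inside chatglm-cpp
    ids.append(token)
    if token == chatglm_eos_token(ctx):
        break
print(chatglm_detokenize(ctx, ids))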