[LLM] Support chatglm cache (#8745)
parent faaccb64a2
commit 77844125f2
2 changed files with 30 additions and 49 deletions
File 1 of 2:

@@ -46,8 +46,8 @@
 # only search the first bigdl package and end up finding only one sub-package.
 
 
-from .chatglm_cpp import chatglm_load, chatglm_tokenize, chatglm_detokenize, chatglm_eval, \
-    chatglm_eos_token
+from .chatglm_cpp import chatglm_load, chatglm_tokenize, chatglm_detokenize, \
+    chatglm_forward, chatglm_eos_token
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.model.generation import GenerationMixin
 from typing import List, Optional, Generator, Sequence, Union
@@ -219,21 +219,16 @@ class ChatGLM(GenerationMixin):
                 }
             }
 
-            n_past = 0
-            output_tokens = input_tokens
             for i in range(max_tokens):
                 token = self.forward(input_ids=input_tokens,
-                                     n_past=n_past,
                                      top_k=top_k,
                                      top_p=top_p,
                                      temperature=temperature)
-                output_tokens.append(token)
-                n_past += len(input_tokens)
-                input_tokens = [token]
+                input_tokens.append(token)
                 if token == self.eos_token():
                     break
 
-            text = self.detokenize(output_tokens)
+            text = self.detokenize(input_tokens)
             split_text = text[len(prompt):]
             split_text.rstrip('�')  # remove partial emoji
             if stop != []:

@@ -243,7 +238,7 @@ class ChatGLM(GenerationMixin):
                 finish_reason = "stop"
             else:
                 finish_reason = None
-            completion_len = n_past - prompt_len
+            completion_len = len(input_tokens) - prompt_len
             return {
                 "id": completion_id,
                 "object": "text_completion",
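Note on the hunks above: the completion loop no longer keeps a separate `output_tokens` list or an `n_past` counter. Each sampled token is appended to `input_tokens` and the full sequence is passed back into `forward`; judging by the commit title, the past-token cache is presumably handled inside chatglm-cpp now rather than in this Python wrapper. A minimal sketch of the new loop shape, with the model call stubbed out (the names below are stand-ins, not the bigdl-llm API):

from typing import Callable, List

def greedy_complete(forward: Callable[[List[int]], int],
                    eos_token: int,
                    input_tokens: List[int],
                    max_tokens: int) -> List[int]:
    # No n_past bookkeeping: the whole running token list is passed on every
    # step, mirroring the post-change completion loop above.
    for _ in range(max_tokens):
        token = forward(input_tokens)   # next token id from the model
        input_tokens.append(token)      # grow the sequence in place
        if token == eos_token:
            break
    return input_tokens

# Toy run with a fake model that always emits token id 2, treated as EOS here.
print(greedy_complete(lambda toks: 2, eos_token=2, input_tokens=[1, 5, 7], max_tokens=4))
# -> [1, 5, 7, 2]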
@@ -288,28 +283,22 @@ class ChatGLM(GenerationMixin):
                         "finish_reason": "length",
                     }
                 ],
-                "usage":
-                {
-                    "prompt_tokens": prompt_len
+                "usage": {
+                    "prompt_tokens": prompt_len
                 }
             }
         else:
-            n_past = 0
-            output_tokens = input_tokens
             history_text = prompt
             for i in range(max_tokens):
                 token = self.forward(input_ids=input_tokens,
-                                     n_past=n_past,
                                      top_k=top_k,
                                      top_p=top_p,
                                      temperature=temperature)
-                output_tokens.append(token)
-                n_past += len(input_tokens)
-                input_tokens = [token]
+                input_tokens.append(token)
                 if token == self.eos_token():
                     print('\n')
                     break
-                text = self.detokenize(output_tokens)
+                text = self.detokenize(input_tokens)
                 if text.endswith('�'):
                     # generated new token is part of an emoji
                     # (some emoji consists of multiple tokens)
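The `rstrip('�')` and `endswith('�')` checks (rendered as `<EFBFBD>` in the raw extract; EF BF BD is the UTF-8 encoding of U+FFFD) guard against partially generated emoji. As the in-code comments say, some emoji span multiple tokens, and decoding an incomplete one presumably yields the replacement character until the remaining bytes arrive. A small self-contained illustration of that decoding behaviour:

# Decoding an incomplete multi-byte character yields U+FFFD ('�'),
# which is what the wrapper looks for before trimming or continuing.
robot = "\U0001F916".encode("utf-8")                  # a 4-byte emoji
partial = robot[:2].decode("utf-8", errors="replace")
print(partial.endswith("\ufffd"))                     # True: looks like '�'
print(robot.decode("utf-8"))                          # the full emoji once all bytes are present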
@@ -330,17 +319,16 @@ class ChatGLM(GenerationMixin):
                         "finish_reason": None,
                     }
                 ],
-                "usage":
-                {
-                    "prompt_tokens": prompt_len
+                "usage": {
+                    "prompt_tokens": prompt_len
                 }
             }
 
-    def _tokenize(self, text: bytes, *args) -> List[int]:
+    def _tokenize(self, text: str, *args) -> List[int]:
         """Tokenize a string.
 
         Args:
-            text: The utf-8 encoded string to tokenize.
+            text: The string to tokenize.
 
         Raises:
             RuntimeError: If the tokenization failed.
@@ -366,18 +354,16 @@ class ChatGLM(GenerationMixin):
 
     def forward(self,
                 input_ids: List[int],
-                n_past: int,
                 do_sample: bool = True,
                 top_k: int = 0,
                 top_p: float = 0.7,
                 temperature: float = 0.95,) -> int:
-        return chatglm_eval(ctx=self.ctx,
-                            input_ids=input_ids,
-                            n_past=n_past,
-                            do_sample=do_sample,
-                            top_k=top_k,
-                            top_p=top_p,
-                            temperature=temperature)
+        return chatglm_forward(ctx=self.ctx,
+                               input_ids=input_ids,
+                               do_sample=do_sample,
+                               top_k=top_k,
+                               top_p=top_p,
+                               temperature=temperature)
 
     def eos_token(self) -> int:
         return chatglm_eos_token(self.ctx)

@@ -431,15 +417,12 @@ class ChatGLM(GenerationMixin):
                               "unsupported, please use the default value.")
 
         invalidInputError(self.ctx is not None, "The attribute `ctx` of `ChatGLM` object is None.")
-        n_past = 0
         while True:
             token = self.forward(input_ids=tokens,
-                                 n_past=n_past,
                                  top_k=top_k,
                                  top_p=top_p,
                                  temperature=temp)
-            n_past += len(tokens)
             tokens_or_none = yield token
-            tokens = [token]
+            tokens.append(token)
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)
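The token-yielding generator in the last hunk follows the same pattern as the completion paths: no `n_past` bookkeeping, and the running `tokens` list keeps growing via `append` instead of being reset to the single new token. A sketch of that send-aware generator shape, again with the model call replaced by a stand-in (the EOS early exit is an assumption added for the toy run; the hunk itself loops until the caller stops iterating):

from typing import Callable, Generator, List, Optional

def stream_tokens(forward: Callable[[List[int]], int],
                  eos_token: int,
                  tokens: List[int]) -> Generator[int, Optional[List[int]], None]:
    # Mirrors the reworked generator: yield each new token, keep the full
    # context growing, and accept extra tokens pushed in via send().
    while True:
        token = forward(tokens)
        tokens_or_none = yield token
        tokens.append(token)
        if tokens_or_none is not None:
            tokens.extend(tokens_or_none)
        if token == eos_token:          # assumption for this toy example only
            return

# Toy run: a scripted "model" emits 7, 7, then 0; 0 is treated as EOS.
script = iter([7, 7, 0])
gen = stream_tokens(lambda toks: next(script), eos_token=0, tokens=[1, 2])
print(list(gen))                        # -> [7, 7, 0]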
File 2 of 2:

@@ -40,7 +40,7 @@ def chatglm_load(path: str,
     path = str(Path(path))
     pipeline = Pipeline(path, use_mmap)
     config = GenerationConfig(
-        max_context_length=n_ctx,
+        max_length=n_ctx,
         num_threads=n_threads,
     )
     return ChatGLMContext(pipeline, config)

@@ -54,20 +54,18 @@ def chatglm_detokenize(ctx: ChatGLMContext, input_ids: List[int]) -> str:
     return ctx.pipeline.tokenizer.decode(input_ids)
 
 
-def chatglm_eval(ctx: ChatGLMContext,
-                 input_ids: List[int],
-                 n_past: int,
-                 do_sample: bool = True,
-                 top_k: int = 0,
-                 top_p: float = 0.7,
-                 temperature: float = 0.95,
-                 ) -> int:
+def chatglm_forward(ctx: ChatGLMContext,
+                    input_ids: List[int],
+                    do_sample: bool = True,
+                    top_k: int = 0,
+                    top_p: float = 0.7,
+                    temperature: float = 0.95,
+                    ) -> int:
     ctx.config.do_sample = do_sample
     ctx.config.top_k = top_k
     ctx.config.top_p = top_p
-    ctx.temperature = temperature
-    return ctx.pipeline.model.generate_next_token(input_ids, ctx.config, n_past,
-                                                  ctx.config.max_context_length)
+    ctx.config.temperature = temperature
+    return ctx.pipeline.forward(input_ids, ctx.config)
 
 
 def chatglm_eos_token(ctx: ChatGLMContext):
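Taken together, the second file's helpers support a completion loop equivalent to the one in the first file. The sketch below is hedged: only the call shapes of `chatglm_forward`, `chatglm_detokenize` and `chatglm_eos_token` come from this diff; the import path, the `chatglm_load` call and the `chatglm_tokenize` argument order are assumptions.

# Hedged sketch only: the import path is assumed from the relative import in
# file 1; chatglm_load and chatglm_tokenize are called with assumed signatures.
from bigdl.llm.ggml.model.chatglm.chatglm_cpp import (
    chatglm_load, chatglm_tokenize, chatglm_detokenize,
    chatglm_forward, chatglm_eos_token,
)

def complete(model_path: str, prompt: str, max_tokens: int = 128) -> str:
    ctx = chatglm_load(model_path)               # likely also needs use_mmap/n_ctx/n_threads arguments
    tokens = chatglm_tokenize(ctx, prompt)       # (ctx, text) order assumed, mirroring chatglm_detokenize
    for _ in range(max_tokens):
        token = chatglm_forward(ctx=ctx,
                                input_ids=tokens,
                                do_sample=True,
                                top_k=0,
                                top_p=0.7,
                                temperature=0.95)
        tokens.append(token)
        if token == chatglm_eos_token(ctx):
            break
    return chatglm_detokenize(ctx, tokens)[len(prompt):]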