diff --git a/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm.py b/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm.py
index 4ee9d10f..8ef53c63 100644
--- a/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm.py
+++ b/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm.py
@@ -46,8 +46,8 @@
 # only search the first bigdl package and end up finding only one sub-package.
 
-from .chatglm_cpp import chatglm_load, chatglm_tokenize, chatglm_detokenize, chatglm_eval, \
-    chatglm_eos_token
+from .chatglm_cpp import chatglm_load, chatglm_tokenize, chatglm_detokenize, \
+    chatglm_forward, chatglm_eos_token
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.model.generation import GenerationMixin
 from typing import List, Optional, Generator, Sequence, Union
@@ -219,21 +219,16 @@ class ChatGLM(GenerationMixin):
                 }
             }
 
-        n_past = 0
-        output_tokens = input_tokens
         for i in range(max_tokens):
             token = self.forward(input_ids=input_tokens,
-                                 n_past=n_past,
                                  top_k=top_k,
                                  top_p=top_p,
                                  temperature=temperature)
-            output_tokens.append(token)
-            n_past += len(input_tokens)
-            input_tokens = [token]
+            input_tokens.append(token)
             if token == self.eos_token():
                 break
-        text = self.detokenize(output_tokens)
+        text = self.detokenize(input_tokens)
         split_text = text[len(prompt):]
         split_text.rstrip('�')  # remove partial emoji
         if stop != []:
@@ -243,7 +238,7 @@ class ChatGLM(GenerationMixin):
             finish_reason = "stop"
         else:
             finish_reason = None
-        completion_len = n_past - prompt_len
+        completion_len = len(input_tokens) - prompt_len
         return {
             "id": completion_id,
             "object": "text_completion",
@@ -288,28 +283,22 @@ class ChatGLM(GenerationMixin):
                         "finish_reason": "length",
                     }
                 ],
-                "usage":
-                {
-                    "prompt_tokens": prompt_len
+                "usage": {
+                    "prompt_tokens": prompt_len
                 }
             }
         else:
-            n_past = 0
-            output_tokens = input_tokens
             history_text = prompt
             for i in range(max_tokens):
                 token = self.forward(input_ids=input_tokens,
-                                     n_past=n_past,
                                      top_k=top_k,
                                      top_p=top_p,
                                      temperature=temperature)
-                output_tokens.append(token)
-                n_past += len(input_tokens)
-                input_tokens = [token]
+                input_tokens.append(token)
                 if token == self.eos_token():
                     print('\n')
                     break
-                text = self.detokenize(output_tokens)
+                text = self.detokenize(input_tokens)
                 if text.endswith('�'):
                     # generated new token is part of an emoji
                     # (some emoji consists of multiple tokens)
@@ -330,17 +319,16 @@ class ChatGLM(GenerationMixin):
                            "finish_reason": None,
                        }
                    ],
-                    "usage":
-                    {
-                        "prompt_tokens": prompt_len
+                    "usage": {
+                        "prompt_tokens": prompt_len
                    }
                }
 
-    def _tokenize(self, text: bytes, *args) -> List[int]:
+    def _tokenize(self, text: str, *args) -> List[int]:
        """Tokenize a string.
 
        Args:
-            text: The utf-8 encoded string to tokenize.
+            text: The string to tokenize.
 
        Raises:
            RuntimeError: If the tokenization failed.
@@ -366,18 +354,16 @@ class ChatGLM(GenerationMixin):
 
    def forward(self,
                input_ids: List[int],
-                n_past: int,
                do_sample: bool = True,
                top_k: int = 0,
                top_p: float = 0.7,
                temperature: float = 0.95,) -> int:
-        return chatglm_eval(ctx=self.ctx,
-                            input_ids=input_ids,
-                            n_past=n_past,
-                            do_sample=do_sample,
-                            top_k=top_k,
-                            top_p=top_p,
-                            temperature=temperature)
+        return chatglm_forward(ctx=self.ctx,
+                               input_ids=input_ids,
+                               do_sample=do_sample,
+                               top_k=top_k,
+                               top_p=top_p,
+                               temperature=temperature)
 
    def eos_token(self) -> int:
        return chatglm_eos_token(self.ctx)
@@ -431,15 +417,12 @@ class ChatGLM(GenerationMixin):
                          "unsupported, please use the default value.")
        invalidInputError(self.ctx is not None,
                          "The attribute `ctx` of `ChatGLM` object is None.")
-        n_past = 0
        while True:
            token = self.forward(input_ids=tokens,
-                                 n_past=n_past,
                                 top_k=top_k,
                                 top_p=top_p,
                                 temperature=temp)
-            n_past += len(tokens)
            tokens_or_none = yield token
-            tokens = [token]
+            tokens.append(token)
            if tokens_or_none is not None:
                tokens.extend(tokens_or_none)
diff --git a/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm_cpp.py b/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm_cpp.py
index 12b1a45e..47513e54 100644
--- a/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm_cpp.py
+++ b/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm_cpp.py
@@ -40,7 +40,7 @@ def chatglm_load(path: str,
    path = str(Path(path))
    pipeline = Pipeline(path, use_mmap)
    config = GenerationConfig(
-        max_context_length=n_ctx,
+        max_length=n_ctx,
        num_threads=n_threads,
    )
    return ChatGLMContext(pipeline, config)
@@ -54,20 +54,18 @@ def chatglm_detokenize(ctx: ChatGLMContext, input_ids: List[int]) -> str:
    return ctx.pipeline.tokenizer.decode(input_ids)
 
 
-def chatglm_eval(ctx: ChatGLMContext,
-                 input_ids: List[int],
-                 n_past: int,
-                 do_sample: bool = True,
-                 top_k: int = 0,
-                 top_p: float = 0.7,
-                 temperature: float = 0.95,
-                 ) -> int:
+def chatglm_forward(ctx: ChatGLMContext,
+                    input_ids: List[int],
+                    do_sample: bool = True,
+                    top_k: int = 0,
+                    top_p: float = 0.7,
+                    temperature: float = 0.95,
+                    ) -> int:
    ctx.config.do_sample = do_sample
    ctx.config.top_k = top_k
    ctx.config.top_p = top_p
-    ctx.temperature = temperature
-    return ctx.pipeline.model.generate_next_token(input_ids, ctx.config, n_past,
-                                                  ctx.config.max_context_length)
+    ctx.config.temperature = temperature
+    return ctx.pipeline.forward(input_ids, ctx.config)
 
 
 def chatglm_eos_token(ctx: ChatGLMContext):
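
For illustration, a minimal sketch of the module-level API after this change, built only from names visible in the diff (chatglm_load, chatglm_tokenize, chatglm_forward, chatglm_detokenize, chatglm_eos_token). The keyword names for chatglm_load and the (ctx, text) shape of chatglm_tokenize are assumptions taken from the hunk context, not confirmed signatures; the model path is a placeholder.

# Illustrative sketch only; not part of the patch.
from bigdl.llm.ggml.model.chatglm.chatglm_cpp import (
    chatglm_load, chatglm_tokenize, chatglm_detokenize,
    chatglm_forward, chatglm_eos_token,
)

# Hypothetical model path; keyword names assumed from the chatglm_load hunk above.
ctx = chatglm_load(path="chatglm2-ggml.bin", use_mmap=True, n_ctx=512, n_threads=4)

prompt = "What is AI?"
max_tokens = 128

tokens = chatglm_tokenize(ctx, prompt)
for _ in range(max_tokens):
    # The full token list is passed on every step; the per-call n_past
    # bookkeeping is gone, matching the new chatglm_forward signature.
    token = chatglm_forward(ctx=ctx, input_ids=tokens,
                            do_sample=True, top_k=0, top_p=0.7, temperature=0.95)
    tokens.append(token)
    if token == chatglm_eos_token(ctx):
        break

print(chatglm_detokenize(ctx, tokens)[len(prompt):])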