diff --git a/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm.py b/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm.py
index a6891705..794fc2c2 100644
--- a/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm.py
+++ b/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm.py
@@ -220,7 +220,7 @@ class ChatGLM(GenerationMixin):
             }
 
         n_past = 0
-        output_tokens = []
+        output_tokens = input_tokens
         for i in range(max_tokens):
             token = self.forward(input_ids=input_tokens,
                                  n_past=n_past,
@@ -234,7 +234,7 @@ class ChatGLM(GenerationMixin):
                 break
 
         text = self.detokenize(output_tokens)
-        split_text = text
+        split_text = text[len(prompt):]
         if stop != []:
             for stop_word in stop:
                 split_text = split_text.split(stop_word)[0]
@@ -294,7 +294,8 @@ class ChatGLM(GenerationMixin):
             }
         else:
             n_past = 0
-            output_tokens = []
+            output_tokens = input_tokens
+            history_text = prompt
             for i in range(max_tokens):
                 token = self.forward(input_ids=input_tokens,
                                      n_past=n_past,
@@ -307,7 +308,9 @@ class ChatGLM(GenerationMixin):
                 if token == self.eos_token():
                     print('\n')
                     break
-                text = self.detokenize(token)
+                text = self.detokenize(output_tokens)
+                text = text[len(history_text):]
+                history_text += text
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
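
For context, a minimal, self-contained sketch of the bug this diff fixes in the streaming branch. The `ToyTokenizer` below is hypothetical (it mimics a SentencePiece-style tokenizer where `▁` marks a word boundary, not ChatGLM's actual tokenizer): detokenizing one token at a time, as the old `self.detokenize(token)` did, drops those word boundaries, while detokenizing the accumulated `output_tokens` and slicing off the already-yielded `history_text`, as this diff does, streams the correct text.

```python
class ToyTokenizer:
    """Hypothetical SentencePiece-style tokenizer: '▁' marks a word start."""

    def detokenize(self, pieces: list) -> str:
        # Join pieces, turn word-start markers into spaces, drop the leading one.
        return "".join(pieces).replace("\u2581", " ").lstrip()


tok = ToyTokenizer()
stream = ["\u2581Hello", "\u2581wor", "ld", "!"]

# Buggy (old code): detokenize each token alone -> "Helloworld!" (spaces lost,
# because a single piece carries no context about the word boundary).
print("".join(tok.detokenize([p]) for p in stream))

# Fixed (this diff): detokenize the whole accumulated sequence, then yield
# only the suffix that has not been emitted yet.
history_text = ""
output_tokens = []
for piece in stream:
    output_tokens.append(piece)
    text = tok.detokenize(output_tokens)
    new_text = text[len(history_text):]
    history_text += new_text
    print(repr(new_text))  # 'Hello', ' wor', 'ld', '!'
```

The non-streaming hunks apply the same idea: seed `output_tokens` with `input_tokens`, detokenize prompt and completion together, then strip `len(prompt)` characters so only the completion is returned.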