From 77efcf7b1d3481d3978c2a7ffd39132c698783d0 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Fri, 11 Aug 2023 14:51:50 +0800
Subject: [PATCH] LLM: fix ChatGLM2 native int4 stream output (#8733)

---
 .../llm/src/bigdl/llm/ggml/model/chatglm/chatglm.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm.py b/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm.py
index a6891705..794fc2c2 100644
--- a/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm.py
+++ b/python/llm/src/bigdl/llm/ggml/model/chatglm/chatglm.py
@@ -220,7 +220,7 @@ class ChatGLM(GenerationMixin):
         }
 
         n_past = 0
-        output_tokens = []
+        output_tokens = input_tokens
         for i in range(max_tokens):
             token = self.forward(input_ids=input_tokens,
                                  n_past=n_past,
@@ -234,7 +234,7 @@ class ChatGLM(GenerationMixin):
                 break
 
             text = self.detokenize(output_tokens)
-            split_text = text
+            split_text = text[len(prompt):]
             if stop != []:
                 for stop_word in stop:
                     split_text = split_text.split(stop_word)[0]
@@ -294,7 +294,8 @@ class ChatGLM(GenerationMixin):
             }
         else:
             n_past = 0
-            output_tokens = []
+            output_tokens = input_tokens
+            history_text = prompt
             for i in range(max_tokens):
                 token = self.forward(input_ids=input_tokens,
                                      n_past=n_past,
@@ -307,7 +308,9 @@ class ChatGLM(GenerationMixin):
                 if token == self.eos_token():
                     print('\n')
                     break
-                text = self.detokenize(token)
+                text = self.detokenize(output_tokens)
+                text = text[len(history_text):]
+                history_text += text
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
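
The stream path previously detokenized each generated token on its own, which can produce
broken text when one character is split across several tokens. With this change the prompt
tokens are kept in output_tokens so the tokenizer always decodes a full, consistent
sequence; the accumulated text is re-decoded on every step and only the not-yet-emitted
suffix (tracked via history_text) is yielded. The non-stream path gets the matching fix:
since output_tokens now starts from the prompt, the prompt prefix is sliced off before
stop-word handling. Below is a minimal sketch of the streaming pattern, with hypothetical
forward/detokenize callables standing in for the real bigdl.llm ChatGLM methods:

    def stream_generate(forward, detokenize, input_tokens, prompt,
                        eos_token, max_tokens=256):
        # Keep the prompt tokens in the decode history so detokenize()
        # always sees full context; copy to avoid mutating the caller's list.
        output_tokens = list(input_tokens)
        history_text = prompt   # text already yielded to the caller
        n_past = 0              # simplified KV-cache bookkeeping
        for _ in range(max_tokens):
            token = forward(input_ids=input_tokens, n_past=n_past)
            n_past += len(input_tokens)
            input_tokens = [token]
            output_tokens.append(token)
            if token == eos_token:
                break
            # Decode the full history, then slice off what was already emitted.
            text = detokenize(output_tokens)
            new_text = text[len(history_text):]
            history_text += new_text
            yield new_text

Re-decoding the whole history on every step costs a full detokenize per token, but it
guarantees that the concatenation of the yielded chunks always equals
detokenize(output_tokens) minus the prompt, which per-token decoding cannot guarantee.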