From e5ae6f2c13109b0aae7e4b4d7318dbb3d415653d Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Fri, 26 Jan 2024 16:56:02 +0800
Subject: [PATCH] LLM: fix truncation logic of past_key_values in chatglm
 multi turn chat (#10007)

* Avoid frequently truncating past_key_values when its length is larger
  than required.
---
 python/llm/portable-zip/chat.py | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/python/llm/portable-zip/chat.py b/python/llm/portable-zip/chat.py
index 94e454d1..b5cc48fc 100644
--- a/python/llm/portable-zip/chat.py
+++ b/python/llm/portable-zip/chat.py
@@ -142,7 +142,9 @@ def chatglm3_stream_chat(model, tokenizer):
     current_length = 0
     # https://github.com/THUDM/ChatGLM3/issues/274#issuecomment-1810160305
     stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(["<|user|>", "<|observation|>"], tokenizer)])
-    max_past_length = 2048
+    # you could change this according to your memory requirement
+    max_past_length = 512
+    block_length = 512
 
     while True:
         user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
@@ -158,6 +160,9 @@
 {user_input}
 <|assistant|>
 """
+        if past_key_values is not None and past_key_values[0][0].shape[0] > max_past_length + block_length:
+            # To avoid out of memory, only keep recent key_values of max_past_length
+            past_key_values = [(k[-max_past_length:, :, :, :], v[-max_past_length:, :, :, :]) for k, v in past_key_values]
         for response, chat_history, past_key_values in model.stream_chat(tokenizer, prompt,
                                                                          history=chat_history,
                                                                          stopping_criteria=stopping_criteria,
@@ -165,16 +170,6 @@
                                                                          return_past_key_values=True):
             print(response[current_length:], end="", flush=True)
             current_length = len(response)
-        if past_key_values[0][0].shape[0] > max_past_length:
-            # To avoid out of memory, only keep recent key_values
-            new_values_list = []
-            for i in range(len(past_key_values)):
-                new_value = []
-                for val in past_key_values[i]:
-                    new_v = val[-max_past_length:]
-                    new_value.append(new_v)
-                new_values_list.append(tuple(new_value))
-            past_key_values = tuple(new_values_list)
 
 
 @torch.no_grad()
 def qwen_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):
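
Note: the change above truncates with hysteresis. Instead of re-slicing the cache on every turn once it exceeds max_past_length, the cache is left alone until it has grown a further block_length beyond that budget, and only then is it cut back to the most recent max_past_length positions. A minimal standalone sketch of that policy follows; the helper name and defaults are illustrative (not part of the patch), and the sequence dimension is assumed to be dim 0 of each cached tensor, as in the slicing used by the patch.

def maybe_truncate_past_key_values(past_key_values, max_past_length=512, block_length=512):
    # Nothing cached yet: nothing to truncate.
    if past_key_values is None:
        return None
    # Only truncate once the cache is a full block beyond the budget,
    # then keep the most recent max_past_length positions of each key/value tensor.
    if past_key_values[0][0].shape[0] > max_past_length + block_length:
        return [(k[-max_past_length:], v[-max_past_length:]) for k, v in past_key_values]
    return past_key_values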