LLM: fix truncation logic of past_key_values in chatglm multi turn chat (#10007)

* Avoid frequently truncating past_key_values when its length is larger than required: only truncate once the cache grows a full block_length beyond max_past_length.
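
Below is a minimal sketch of the truncation policy this commit introduces, separate from the example code itself. The helper name truncate_past is hypothetical, and the sketch assumes ChatGLM-style past_key_values, i.e. a sequence of (key, value) tensors whose first dimension is the cached sequence length.

def truncate_past(past_key_values, max_past_length=512, block_length=512):
    """Sketch: drop old cache entries only after the cache grows a full block past the target length."""
    if past_key_values is None:
        return past_key_values
    # Assumption: the cached sequence length is the first dimension of each key/value tensor.
    seq_len = past_key_values[0][0].shape[0]
    if seq_len > max_past_length + block_length:
        # Keep only the most recent max_past_length positions of each key/value tensor.
        past_key_values = [(k[-max_past_length:], v[-max_past_length:])
                           for k, v in past_key_values]
    return past_key_values

With the defaults above, the cache may grow to 1024 cached positions before it is cut back to the most recent 512, so truncation happens at most once per block_length newly generated tokens instead of on every streamed token once the limit is exceeded.
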
Author: binbin Deng, 2024-01-26 16:56:02 +08:00 (committed by GitHub)
Parent: 1eaaace2dc
Commit: e5ae6f2c13

@@ -142,7 +142,9 @@ def chatglm3_stream_chat(model, tokenizer):
     current_length = 0
     # https://github.com/THUDM/ChatGLM3/issues/274#issuecomment-1810160305
     stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(["<|user|>", "<|observation|>"], tokenizer)])
-    max_past_length = 2048
+    # you could change this according to your memory requirement
+    max_past_length = 512
+    block_length = 512
     while True:
         user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
@@ -158,6 +160,9 @@ def chatglm3_stream_chat(model, tokenizer):
 {user_input}
 <|assistant|>
 """
+        if past_key_values is not None and past_key_values[0][0].shape[0] > max_past_length + block_length:
+            # To avoid out of memory, only keep recent key_values of max_past_length
+            past_key_values = [(k[-max_past_length:, :, :, :], v[-max_past_length:, :, :, :]) for k, v in past_key_values]
         for response, chat_history, past_key_values in model.stream_chat(tokenizer, prompt,
                                                                          history=chat_history,
                                                                          stopping_criteria=stopping_criteria,
@@ -165,16 +170,6 @@
                                                                          return_past_key_values=True):
             print(response[current_length:], end="", flush=True)
             current_length = len(response)
-            if past_key_values[0][0].shape[0] > max_past_length:
-                # To avoid out of memory, only keep recent key_values
-                new_values_list = []
-                for i in range(len(past_key_values)):
-                    new_value = []
-                    for val in past_key_values[i]:
-                        new_v = val[-max_past_length:]
-                        new_value.append(new_v)
-                    new_values_list.append(tuple(new_value))
-                past_key_values = tuple(new_values_list)
 
 @torch.no_grad()
 def qwen_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):