LLM: fix truncation logic of past_key_values in chatglm multi turn chat (#10007)
* Avoid frequently truncating past_key_values when its length is larger than required.
This commit is contained in:
parent
1eaaace2dc
commit
e5ae6f2c13
1 changed files with 6 additions and 11 deletions
|
|
@ -142,7 +142,9 @@ def chatglm3_stream_chat(model, tokenizer):
|
|||
current_length = 0
|
||||
# https://github.com/THUDM/ChatGLM3/issues/274#issuecomment-1810160305
|
||||
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(["<|user|>", "<|observation|>"], tokenizer)])
|
||||
max_past_length = 2048
|
||||
# you could change this according to your memory requirement
|
||||
max_past_length = 512
|
||||
block_length = 512
|
||||
|
||||
while True:
|
||||
user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
|
||||
|
|
@ -158,6 +160,9 @@ def chatglm3_stream_chat(model, tokenizer):
|
|||
{user_input}
|
||||
<|assistant|>
|
||||
"""
|
||||
if past_key_values is not None and past_key_values[0][0].shape[0] > max_past_length + block_length:
|
||||
# To avoid out of memory, only keep recent key_values of max_past_length
|
||||
past_key_values = [(k[-max_past_length:, :, :, :], v[-max_past_length:, :, :, :]) for k, v in past_key_values]
|
||||
for response, chat_history, past_key_values in model.stream_chat(tokenizer, prompt,
|
||||
history=chat_history,
|
||||
stopping_criteria=stopping_criteria,
|
||||
|
|
@ -165,16 +170,6 @@ def chatglm3_stream_chat(model, tokenizer):
|
|||
return_past_key_values=True):
|
||||
print(response[current_length:], end="", flush=True)
|
||||
current_length = len(response)
|
||||
if past_key_values[0][0].shape[0] > max_past_length:
|
||||
# To avoid out of memory, only keep recent key_values
|
||||
new_values_list = []
|
||||
for i in range(len(past_key_values)):
|
||||
new_value = []
|
||||
for val in past_key_values[i]:
|
||||
new_v = val[-max_past_length:]
|
||||
new_value.append(new_v)
|
||||
new_values_list.append(tuple(new_value))
|
||||
past_key_values = tuple(new_values_list)
|
||||
|
||||
@torch.no_grad()
|
||||
def qwen_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):
|
||||
|
|
|
|||
Loading…
Reference in a new issue