LLM: fix truncation logic of past_key_values in chatglm multi turn chat (#10007)
* Avoid frequently truncating past_key_values when its length is larger than required.
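Why this reduces truncation frequency, with a minimal standalone sketch (not the example script itself; the shapes and the make_fake_past_key_values helper below are made up for illustration, assuming ChatGLM3-style caches with the sequence dimension first): the old loop sliced past_key_values back to max_past_length on every generated token once the cache crossed that limit, so the slicing cost was paid at every step. The new check allows an extra block_length of slack before truncating, so the slice runs at most once per block_length tokens.

import torch

# Hypothetical helper, for illustration only: builds a cache shaped like
# ChatGLM3's past_key_values, i.e. one (key, value) pair per layer with the
# sequence length in dimension 0.
def make_fake_past_key_values(seq_len, num_layers=2, batch=1, num_heads=2, head_dim=4):
    return [(torch.zeros(seq_len, batch, num_heads, head_dim),
             torch.zeros(seq_len, batch, num_heads, head_dim))
            for _ in range(num_layers)]

max_past_length = 512
block_length = 512

def truncate_old(past_key_values):
    # Old logic: fires on every step once the cache exceeds max_past_length.
    if past_key_values[0][0].shape[0] > max_past_length:
        past_key_values = [(k[-max_past_length:], v[-max_past_length:])
                           for k, v in past_key_values]
    return past_key_values

def truncate_new(past_key_values):
    # New logic: only fires after block_length extra tokens have accumulated,
    # then cuts back to the most recent max_past_length entries.
    if past_key_values is not None and past_key_values[0][0].shape[0] > max_past_length + block_length:
        past_key_values = [(k[-max_past_length:], v[-max_past_length:])
                           for k, v in past_key_values]
    return past_key_values

pkv = make_fake_past_key_values(seq_len=max_past_length + 1)
print(truncate_old(pkv)[0][0].shape[0])  # 512 -- already truncating at 513 tokens
print(truncate_new(pkv)[0][0].shape[0])  # 513 -- still within the slack, left untouched

pkv = make_fake_past_key_values(seq_len=max_past_length + block_length + 1)
print(truncate_new(pkv)[0][0].shape[0])  # 512 -- truncated once the slack is used up

With max_past_length = block_length = 512 as in the diff, the cache fluctuates between 512 and 1024 recent tokens, trading a bounded amount of extra memory for far fewer slice operations.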
This commit is contained in:

parent 1eaaace2dc
commit e5ae6f2c13

1 changed file with 6 additions and 11 deletions
@@ -142,7 +142,9 @@ def chatglm3_stream_chat(model, tokenizer):
     current_length = 0
     # https://github.com/THUDM/ChatGLM3/issues/274#issuecomment-1810160305
     stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(["<|user|>", "<|observation|>"], tokenizer)])
-    max_past_length = 2048
+    # you could change this according to your memory requirement
+    max_past_length = 512
+    block_length = 512
 
     while True:
         user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
@@ -158,6 +160,9 @@ def chatglm3_stream_chat(model, tokenizer):
             {user_input}
             <|assistant|>
         """
+        if past_key_values is not None and past_key_values[0][0].shape[0] > max_past_length + block_length:
+            # To avoid out of memory, only keep recent key_values of max_past_length
+            past_key_values = [(k[-max_past_length:, :, :, :], v[-max_past_length:, :, :, :]) for k, v in past_key_values]
         for response, chat_history, past_key_values in model.stream_chat(tokenizer, prompt,
                                                                          history=chat_history,
                                                                          stopping_criteria=stopping_criteria,
@@ -165,16 +170,6 @@ def chatglm3_stream_chat(model, tokenizer):
                                                                          return_past_key_values=True):
             print(response[current_length:], end="", flush=True)
             current_length = len(response)
-            if past_key_values[0][0].shape[0] > max_past_length:
-                # To avoid out of memory, only keep recent key_values
-                new_values_list = []
-                for i in range(len(past_key_values)):
-                    new_value = []
-                    for val in past_key_values[i]:
-                        new_v = val[-max_past_length:]
-                        new_value.append(new_v)
-                    new_values_list.append(tuple(new_value))
-                past_key_values = tuple(new_values_list)
 
 @torch.no_grad()
 def qwen_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):