LLM: fix truncation logic of past_key_values in chatglm multi turn chat (#10007)
* Avoid frequently truncating past_key_values when its length is larger than required.
This commit is contained in:
		
							parent
							
								
									1eaaace2dc
								
							
						
					
					
						commit
						e5ae6f2c13
					
				
					 1 changed files with 6 additions and 11 deletions
				
			
		| 
						 | 
				
			
			@ -142,7 +142,9 @@ def chatglm3_stream_chat(model, tokenizer):
 | 
			
		|||
    current_length = 0
 | 
			
		||||
    # https://github.com/THUDM/ChatGLM3/issues/274#issuecomment-1810160305
 | 
			
		||||
    stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(["<|user|>", "<|observation|>"], tokenizer)])
 | 
			
		||||
    max_past_length = 2048
 | 
			
		||||
    # you could change this according to your memory requirement
 | 
			
		||||
    max_past_length = 512
 | 
			
		||||
    block_length = 512
 | 
			
		||||
 | 
			
		||||
    while True:
 | 
			
		||||
        user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
 | 
			
		||||
| 
						 | 
				
			
			@ -158,6 +160,9 @@ def chatglm3_stream_chat(model, tokenizer):
 | 
			
		|||
            {user_input}
 | 
			
		||||
            <|assistant|>
 | 
			
		||||
        """
 | 
			
		||||
        if past_key_values is not None and past_key_values[0][0].shape[0] > max_past_length + block_length:
 | 
			
		||||
            # To avoid out of memory, only keep recent key_values of max_past_length
 | 
			
		||||
            past_key_values = [(k[-max_past_length:, :, :, :], v[-max_past_length:, :, :, :]) for k, v in past_key_values]
 | 
			
		||||
        for response, chat_history, past_key_values in model.stream_chat(tokenizer, prompt,
 | 
			
		||||
                                                                         history=chat_history,
 | 
			
		||||
                                                                         stopping_criteria=stopping_criteria,
 | 
			
		||||
| 
						 | 
				
			
			@ -165,16 +170,6 @@ def chatglm3_stream_chat(model, tokenizer):
 | 
			
		|||
                                                                         return_past_key_values=True):
 | 
			
		||||
            print(response[current_length:], end="", flush=True)
 | 
			
		||||
            current_length = len(response)
 | 
			
		||||
            if past_key_values[0][0].shape[0] > max_past_length:
 | 
			
		||||
                # To avoid out of memory, only keep recent key_values
 | 
			
		||||
                new_values_list = []
 | 
			
		||||
                for i in range(len(past_key_values)):
 | 
			
		||||
                    new_value = []
 | 
			
		||||
                    for val in past_key_values[i]:
 | 
			
		||||
                        new_v = val[-max_past_length:]
 | 
			
		||||
                        new_value.append(new_v)
 | 
			
		||||
                    new_values_list.append(tuple(new_value))
 | 
			
		||||
                past_key_values = tuple(new_values_list)
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def qwen_stream_chat(model, tokenizer, kv_cache=None, max_gen_len=512, stop_words=[]):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue