fix lookahead with transformers >= 4.36 (#10808)
This commit is contained in:
		
							parent
							
								
									34ff07b689
								
							
						
					
					
						commit
						57edf2033c
					
				
					 1 changed file with 14 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@ -451,6 +451,20 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False):
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 | 
			
		||||
    if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)):
 | 
			
		||||
        if hasattr(past_key_values, "_seen_tokens"):
 | 
			
		||||
            past_key_values._seen_tokens -= new_cache_size
 | 
			
		||||
        else:
 | 
			
		||||
            past_key_values.seen_tokens -= new_cache_size
 | 
			
		||||
 | 
			
		||||
        for i, k in enumerate(past_key_values.key_cache):
 | 
			
		||||
            past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :]
 | 
			
		||||
        for i, v in enumerate(past_key_values.value_cache):
 | 
			
		||||
            past_key_values.value_cache[i] = v[:, :, :-new_cache_size, :]
 | 
			
		||||
 | 
			
		||||
        return past_key_values
 | 
			
		||||
 | 
			
		||||
    if _enable_ipex:
 | 
			
		||||
        cur_len = past_key_values[0][0].size(1)
 | 
			
		||||
        delta = new_cache_size
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue