fix kv cache out of bound (#9827)
This commit is contained in:
		
							parent
							
								
									5df31db773
								
							
						
					
					
						commit
						f919f5792a
					
				
					 2 changed files with 4 additions and 3 deletions
				
			
		| 
						 | 
				
			
			@ -149,7 +149,7 @@ def llama_attention_forward_4_31(
 | 
			
		|||
        attention_dtype = original_dtype
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
 | 
			
		||||
    qtype = getattr(self.q_proj, "qtype", None)
 | 
			
		||||
    is_q4_0 = qtype == SYM_INT4
 | 
			
		||||
    no_tp = not self.config.pretraining_tp > 1
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -116,10 +116,11 @@ def is_enough_kv_cache_room_4_36(past_key_value, idx):
 | 
			
		|||
        past_key_value.key_cache[idx].size(3)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_enough_kv_cache_room_4_31(past_key_value):
 | 
			
		||||
def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
 | 
			
		||||
    # to determinate if is enough kv cache room in transformers between 4.31 and 4.35
 | 
			
		||||
    return past_key_value is not None and \
 | 
			
		||||
        past_key_value[0].stride()[1] > past_key_value[0].size(2) * past_key_value[0].size(3)
 | 
			
		||||
        past_key_value[0].stride()[1] > \
 | 
			
		||||
        (past_key_value[0].size(2) + seq_len - 1) * past_key_value[0].size(3)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def use_flash_attention(query):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue