quick fix qwen2 fp8 kv cache (#10135)
This commit is contained in:
parent 39d90839aa
commit 4d33aac7f9
1 changed file with 2 additions and 0 deletions
@@ -167,6 +167,8 @@ def qwen2_attention_forward_quantized(
     if q_len != 1:
         key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
+        key = repeat_kv(key, self.num_key_value_groups)
+        value = repeat_kv(value, self.num_key_value_groups)
         attn_weights = torch.matmul(query_states, key.transpose(2, 3))
     else:
         import linear_q4_0
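
For context: the two added repeat_kv calls expand the key/value tensors restored from the fp8 KV cache across query-head groups, which grouped-query attention (GQA) models such as Qwen2 require before the matmul with query_states (otherwise the head dimensions of query and key do not match). The sketch below illustrates what a repeat_kv helper of this kind does; it mirrors the well-known helper of the same name in Hugging Face transformers, and this standalone version is an illustrative assumption, not code from the patched file.

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Expand (batch, num_kv_heads, seq_len, head_dim) to
        # (batch, num_kv_heads * n_rep, seq_len, head_dim) so that each
        # key/value head is shared by n_rep query heads (GQA).
        batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_kv_heads, n_rep, seq_len, head_dim
        )
        return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

    # Example shapes in the style of a GQA model: 28 query heads sharing
    # 4 KV heads gives n_rep = num_key_value_groups = 7.
    key = torch.randn(1, 4, 16, 128)
    assert repeat_kv(key, 7).shape == (1, 28, 16, 128)

Repeating is only needed on the q_len != 1 (prefill) path shown in the hunk, where attention weights are computed with an explicit matmul; the else branch hands decoding off to the linear_q4_0 kernel instead.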