quick fix qwen2 fp8 kv cache (#10135)
This commit is contained in:
		
							parent
							
								
									39d90839aa
								
							
						
					
					
						commit
						4d33aac7f9
					
				
					 1 changed files with 2 additions and 0 deletions
				
			
		| 
						 | 
					@ -167,6 +167,8 @@ def qwen2_attention_forward_quantized(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if q_len != 1:
 | 
					    if q_len != 1:
 | 
				
			||||||
        key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
 | 
					        key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
 | 
				
			||||||
 | 
					        key = repeat_kv(key, self.num_key_value_groups)
 | 
				
			||||||
 | 
					        value = repeat_kv(value, self.num_key_value_groups)
 | 
				
			||||||
        attn_weights = torch.matmul(query_states, key.transpose(2, 3))
 | 
					        attn_weights = torch.matmul(query_states, key.transpose(2, 3))
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        import linear_q4_0
 | 
					        import linear_q4_0
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue