quick fix qwen2 fp8 kv cache (#10135)

Yishuo Wang 2024-02-08 17:04:59 +08:00 committed by GitHub
parent 39d90839aa
commit 4d33aac7f9


@@ -167,6 +167,8 @@ def qwen2_attention_forward_quantized(
     if q_len != 1:
         key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
+        key = repeat_kv(key, self.num_key_value_groups)
+        value = repeat_kv(value, self.num_key_value_groups)
         attn_weights = torch.matmul(query_states, key.transpose(2, 3))
     else:
         import linear_q4_0
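
For context on the fix: Qwen2 uses grouped-query attention, so the key/value tensors returned by restore_fp8_kv_cache carry only num_key_value_heads heads, while query_states carries the full num_attention_heads. The two added repeat_kv calls expand the restored KV heads by self.num_key_value_groups so the subsequent matmul has matching head dimensions. Below is a minimal, self-contained sketch of that expansion; the repeat_kv helper mirrors the standard grouped-query-attention helper, and the head counts (32 query heads, 4 KV heads) are illustrative assumptions rather than the model's actual config.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_kv_heads, seq_len, head_dim) to
    # (batch, num_kv_heads * n_rep, seq_len, head_dim).
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# Illustrative (hypothetical) shapes: 32 query heads vs. 4 KV heads.
query_states = torch.randn(1, 32, 16, 128)
key = torch.randn(1, 4, 16, 128)    # stands in for restore_fp8_kv_cache output
value = torch.randn(1, 4, 16, 128)

n_rep = query_states.shape[1] // key.shape[1]   # plays the role of self.num_key_value_groups
key = repeat_kv(key, n_rep)
value = repeat_kv(value, n_rep)

# With matching head counts the attention matmul is now shape-valid.
attn_weights = torch.matmul(query_states, key.transpose(2, 3))
print(attn_weights.shape)   # torch.Size([1, 32, 16, 16])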