quick fix qwen2 fp8 kv cache (#10135)
This commit is contained in:
parent
39d90839aa
commit
4d33aac7f9
1 changed file with 2 additions and 0 deletions
|
|
@@ -167,6 +167,8 @@ def qwen2_attention_forward_quantized(
|
||||||
|
|
||||||
if q_len != 1:
|
if q_len != 1:
|
||||||
key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
|
key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
|
||||||
|
key = repeat_kv(key, self.num_key_value_groups)
|
||||||
|
value = repeat_kv(value, self.num_key_value_groups)
|
||||||
attn_weights = torch.matmul(query_states, key.transpose(2, 3))
|
attn_weights = torch.matmul(query_states, key.transpose(2, 3))
|
||||||
else:
|
else:
|
||||||
import linear_q4_0
|
import linear_q4_0
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue