Add comment about compressed KV handling in attention forward

This commit is contained in:
Huang, Xinshengzi 2024-08-22 15:11:55 +08:00
parent ce7de77085
commit a2be3d7501

View file

@@ -307,13 +307,11 @@ def baichuan_attention_forward_7b(
is_causal=True).to(hidden_states.dtype)
elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
import xe_addons
if use_compresskv:
attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
if use_quantize_kv:
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
attention_mask)
elif use_compresskv:
attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
attn_output = xe_addons.sdp(query_states, key_states, value_states,
attention_mask)
else:
attn_output = xe_addons.sdp(query_states, key_states, value_states,
attention_mask)