add comment of compress kv in attention forward
This commit is contained in:
parent
ce7de77085
commit
a2be3d7501
1 changed files with 2 additions and 4 deletions
|
|
@ -307,13 +307,11 @@ def baichuan_attention_forward_7b(
|
||||||
is_causal=True).to(hidden_states.dtype)
|
is_causal=True).to(hidden_states.dtype)
|
||||||
elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
||||||
import xe_addons
|
import xe_addons
|
||||||
|
if use_compresskv:
|
||||||
|
attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
|
||||||
if use_quantize_kv:
|
if use_quantize_kv:
|
||||||
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
||||||
attention_mask)
|
attention_mask)
|
||||||
elif use_compresskv:
|
|
||||||
attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
|
|
||||||
attn_output = xe_addons.sdp(query_states, key_states, value_states,
|
|
||||||
attention_mask)
|
|
||||||
else:
|
else:
|
||||||
attn_output = xe_addons.sdp(query_states, key_states, value_states,
|
attn_output = xe_addons.sdp(query_states, key_states, value_states,
|
||||||
attention_mask)
|
attention_mask)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue