add sdp causal support in llama (#11705)

Yina Chen, 2024-08-02 05:27:40 +03:00, committed by GitHub
parent 736a7ef72e
commit 8d1e0bd2f4


@@ -1497,6 +1497,13 @@ def llama_attention_forward_4_41_original(
                                                      value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), new_attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
@@ -2040,6 +2047,13 @@ def llama_attention_forward_4_38_original(
                                                      value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), new_attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
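
Note: both patched forwards (the transformers 4.41 and 4.38 variants) gain the same branch. When use_sdp_causal selects the fused path, xe_addons.sdp_causal computes causal scaled-dot-product attention over contiguous key/value tensors, the output is reshaped back to the query layout, and attn_weights stays None because the fused kernel never materializes the attention matrix. The sketch below is not the xe_addons kernel; it uses PyTorch's stock torch.nn.functional.scaled_dot_product_attention with is_causal=True as a stand-in to illustrate the semantics of this branch. The tensor shapes and the assumption of a purely lower-triangular mask are illustrative only; the real kernel also receives new_attention_mask.

# Reference sketch only, not the xe_addons kernel. PyTorch's fused SDPA with
# is_causal=True reproduces the causal-attention semantics the new branch selects.
# Layout assumed here: [batch, num_heads, seq_len, head_dim] (usual Llama layout).
import torch
import torch.nn.functional as F

def sdp_causal_reference(query_states, key_states, value_states):
    attn_output = F.scaled_dot_product_attention(
        query_states,
        key_states.contiguous(),    # the patch passes contiguous K/V to the kernel
        value_states.contiguous(),
        is_causal=True)             # lower-triangular (causal) masking during prefill
    # Mirror the patch: reshape back to the query layout and report no weights,
    # since the fused path never forms the full attention matrix.
    return attn_output.view(query_states.shape), None

# Hypothetical prefill shapes: batch=1, 32 heads, q_len == kv_seq_len == 1024, head_dim=128.
q = torch.randn(1, 32, 1024, 128)
k = torch.randn(1, 32, 1024, 128)
v = torch.randn(1, 32, 1024, 128)
attn_output, attn_weights = sdp_causal_reference(q, k, v)
print(attn_output.shape)  # torch.Size([1, 32, 1024, 128])

The contiguous() calls on K and V presumably exist because the fused kernel expects dense layouts while cache concatenation can produce non-contiguous tensors; that reading is inferred from the diff, not from documented xe_addons behavior.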