add sdp causal support in llama (#11705)
Parent: 736a7ef72e
Commit: 8d1e0bd2f4
1 changed file with 14 additions and 0 deletions
@@ -1497,6 +1497,13 @@ def llama_attention_forward_4_41_original(
                                                      value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), new_attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
@@ -2040,6 +2047,13 @@ def llama_attention_forward_4_38_original(
                                                      value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), new_attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
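For readers skimming the diff, the sketch below mirrors the dispatch shape the patch extends: try a fused causal SDP kernel, otherwise fall back to a causal attention path. The `xe_addons.sdp_causal` call and the `.view(query_states.shape)` reshape are taken from the diff; the guarded import, the pure-PyTorch fallback, and the toy driver are illustrative assumptions, not ipex-llm's actual helpers (the real code gates the branch with `use_sdp_causal(...)`).

# Sketch only: "use the fused causal kernel when available, otherwise fall back",
# in the spirit of the new elif branch added by this commit.
import torch
import torch.nn.functional as F


def causal_attention(query_states, key_states, value_states, attention_mask=None):
    """Causal attention dispatch: fused xe_addons kernel if importable, else torch SDPA."""
    try:
        import xe_addons  # XPU-only extension; assumed absent on most hosts
        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
                                           value_states.contiguous(), attention_mask)
        return attn_output.view(query_states.shape)
    except ImportError:
        # Fallback (assumption): plain PyTorch scaled dot-product attention with an
        # implicit causal mask, matching the is_causal=True branch shown in the diff.
        return F.scaled_dot_product_attention(query_states, key_states,
                                              value_states, is_causal=True)


if __name__ == "__main__":
    # Toy shapes: (batch, num_heads, seq_len, head_dim)
    q = torch.randn(1, 2, 8, 16)
    k = torch.randn(1, 2, 8, 16)
    v = torch.randn(1, 2, 8, 16)
    print(causal_attention(q, k, v).shape)  # torch.Size([1, 2, 8, 16])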