diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 85f58f61..fd7ecffd 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1497,6 +1497,13 @@ def llama_attention_forward_4_41_original(
                                                      value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), new_attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
@@ -2040,6 +2047,13 @@ def llama_attention_forward_4_38_original(
                                                      value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), new_attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
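
Note: for reference, below is a minimal sketch (outside the diff) of the dispatch pattern both hunks introduce: when the use_sdp_causal gate accepts the current shapes, the fused xe_addons.sdp_causal kernel computes causal attention on contiguous K/V and the result is reshaped back to the query layout; otherwise execution falls through to the existing SDP paths. Only the call shapes of use_sdp_causal and xe_addons.sdp_causal are taken from the diff; the stand-in gate, the fallback branch, and the function signature are illustrative assumptions, not ipex_llm's actual code.

import torch
import torch.nn.functional as F


def _use_sdp_causal_stub(q_len, kv_seq_len, head_dim, query_states, training):
    # Stand-in for ipex_llm's use_sdp_causal gate (same argument order as in
    # the diff; head_dim/kv_seq_len kept only to mirror that signature):
    # only consider the fused causal kernel for inference-time prefill on
    # half-precision XPU tensors.
    return (not training
            and q_len == kv_seq_len
            and query_states.device.type == "xpu"
            and query_states.dtype in (torch.float16, torch.bfloat16))


def attention_core(query_states, key_states, value_states, new_attention_mask,
                   q_len, kv_seq_len, head_dim, training=False):
    if _use_sdp_causal_stub(q_len, kv_seq_len, head_dim, query_states, training):
        import xe_addons  # native kernels shipped with ipex-llm
        # The fused kernel takes contiguous K/V plus the attention mask and
        # returns the attention output, reshaped back to the query layout.
        attn_output = xe_addons.sdp_causal(query_states,
                                           key_states.contiguous(),
                                           value_states.contiguous(),
                                           new_attention_mask)
        attn_output = attn_output.view(query_states.shape)
    else:
        # Illustrative fallback: stock PyTorch SDPA with an implicit causal mask.
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, is_causal=True)
    return attn_output, None  # attention weights are not materialized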