diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py index 8c1be44a..e542c6c5 100644 --- a/python/llm/src/ipex_llm/transformers/models/llama.py +++ b/python/llm/src/ipex_llm/transformers/models/llama.py @@ -1346,26 +1346,10 @@ def llama_attention_forward_4_36_original( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) # otherwise, use native attention - if query_states.device.type == "xpu": - dev_name = torch.xpu.get_device_name(query_states.device.index) - else: - dev_name = "CPU" - if not output_attentions and not dev_name.startswith("Intel(R) Data Center GPU Max"): - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that - # does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - ) - else: - attn_output, attn_weights = native_sdp(query_states, key_states, value_states, - attention_mask, - bsz, q_len, kv_seq_len, - self.head_dim, self.num_heads, output_attentions) + attn_output, attn_weights = native_sdp(query_states, key_states, value_states, + attention_mask, + bsz, q_len, kv_seq_len, + self.head_dim, self.num_heads, output_attentions) attn_output_size = (bsz, self.num_heads, q_len, self.head_dim) if attn_output.size() != attn_output_size: @@ -1789,6 +1773,9 @@ def llama_model_forward_4_36_internal( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + # IPEX-LLM modifications: + # Disable sdpa for CPU + self._use_sdpa = False if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) \