From 31ea2f9a9f3e5af3b798316feb9881f3c4de48a2 Mon Sep 17 00:00:00 2001
From: Guancheng Fu <110874468+gc-fu@users.noreply.github.com>
Date: Thu, 18 Apr 2024 11:07:27 +0800
Subject: [PATCH] Fix wrong output for Llama models on CPU (#10742)

---
 .../src/ipex_llm/transformers/models/llama.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index b9308bc6..cd4054b1 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1335,10 +1335,22 @@ def llama_attention_forward_4_36_original(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
-        attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                               attention_mask,
-                                               bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads, output_attentions)
+        if not output_attentions:
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                dropout_p=self.attention_dropout if self.training else 0.0,
+                # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that
+                # does not create a causal mask in case q_len == 1.
+                is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            )
+        else:
+            attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                                   attention_mask,
+                                                   bsz, q_len, kv_seq_len,
+                                                   self.head_dim, self.num_heads, output_attentions)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
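
Note (not part of the patch): a minimal standalone sketch of the behavior this fix relies on. torch.nn.functional.scaled_dot_product_attention returns only the attention output, never the attention weights, so the patched code can take the fused path only when output_attentions is False and must otherwise fall back to the explicit native_sdp path. The tensor shapes and names below (bsz, num_heads, q_len, kv_len, head_dim) are illustrative and not taken from llama.py.

    # Sketch only: compares the fused SDPA path with an explicit softmax(QK^T/sqrt(d))V
    # path that also exposes the attention weights. Shapes are made up for illustration.
    import math
    import torch
    import torch.nn.functional as F

    bsz, num_heads, q_len, kv_len, head_dim = 1, 8, 16, 16, 64
    q = torch.randn(bsz, num_heads, q_len, head_dim)
    k = torch.randn(bsz, num_heads, kv_len, head_dim)
    v = torch.randn(bsz, num_heads, kv_len, head_dim)

    # Fast path: fused kernel, output only (no attention weights available).
    attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                                 dropout_p=0.0, is_causal=True)

    # Explicit path (analogous in spirit to native_sdp): weights are materialized.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
    causal_mask = torch.triu(torch.full((q_len, kv_len), float("-inf")), diagonal=1)
    attn_weights = torch.softmax(scores + causal_mask, dim=-1)
    attn_output_manual = torch.matmul(attn_weights, v)

    # Both paths produce (numerically close) identical outputs; only the second
    # can also return attn_weights when output_attentions is requested.
    print(torch.allclose(attn_output, attn_output_manual, atol=1e-4))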