diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama.py b/python/llm/src/ipex_llm/transformers/npu_models/llama.py
index 004ecc6a..8a830d9e 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama.py
@@ -219,13 +219,25 @@ def llama_attention_forward(
     else:
         causal_mask = None
 
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_states,
-        key_states,
-        value_states,
-        attn_mask=causal_mask,
-        is_causal=self.is_causal and attention_mask is None and q_len > 1,
-    )
+    if query_states.size(2) == key_states.size(2):
+        # first token
+        from intel_npu_acceleration_library.functional import scaled_dot_product_attention
+        attn_output = scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            is_causal=self.is_causal and causal_mask is None and q_len > 1,
+        )
+    else:
+        # second+ token
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            is_causal=self.is_causal and causal_mask is None and q_len > 1,
+        )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
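
Note: below is a minimal, self-contained sketch of the dispatch this hunk introduces, assuming only PyTorch is installed. When query and key cover the same sequence length (dim 2), the call is the prefill / first-token pass, which the patch routes through intel_npu_acceleration_library's scaled_dot_product_attention; for subsequent decode tokens (q_len < kv_len because of the KV cache) it keeps torch's SDPA. The sdpa_dispatch helper and the ImportError fallback are illustrative only, not part of the patch.

import torch


def sdpa_dispatch(query_states, key_states, value_states, causal_mask, is_causal, q_len):
    if query_states.size(2) == key_states.size(2):
        # Prefill (first token): query and key have the same sequence length.
        try:
            # Fused NPU kernel that the patch uses for the prefill pass.
            from intel_npu_acceleration_library.functional import (
                scaled_dot_product_attention as sdpa,
            )
        except ImportError:
            # Illustrative fallback so this sketch runs without the NPU library.
            sdpa = torch.nn.functional.scaled_dot_product_attention
    else:
        # Decode (second+ token): q_len < kv_len, stay on torch's SDPA.
        sdpa = torch.nn.functional.scaled_dot_product_attention
    return sdpa(
        query_states,
        key_states,
        value_states,
        attn_mask=causal_mask,
        is_causal=is_causal and causal_mask is None and q_len > 1,
    )


# Prefill example: 16 query tokens attend over the same 16 key/value tokens.
q = torch.randn(1, 32, 16, 128)
k = torch.randn(1, 32, 16, 128)
v = torch.randn(1, 32, 16, 128)
out = sdpa_dispatch(q, k, v, causal_mask=None, is_causal=True, q_len=16)
print(out.shape)  # torch.Size([1, 32, 16, 128])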