diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama.py b/python/llm/src/ipex_llm/transformers/npu_models/llama.py
index e06d61e1..e665133b 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama.py
@@ -34,6 +34,7 @@ from typing import Optional, Tuple, List, Union
+import math
 import torch
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -230,14 +231,14 @@ def llama_attention_forward(
             is_causal=self.is_causal and causal_mask is None and q_len > 1,
         )
     else:
-        # second+ token
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=causal_mask,
-            is_causal=self.is_causal and causal_mask is None and q_len > 1,
-        )
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if causal_mask is not None:
+            attn_weights = attn_weights + causal_mask
+
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
 
     attn_output = attn_output.transpose(1, 2).contiguous()
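
The patch replaces the SDPA call on the second-and-later-token (decode) path with an explicit eager attention computation: scores = Q·K^T / sqrt(head_dim), plus the additive causal mask when present, softmax in float32, then a weighted sum over V. Below is a minimal standalone sketch of that same computation, not the patched function itself; the tensor shapes, the zero mask, and the float32 inputs are made up for illustration (in the real half-precision path the float32 softmax upcast is what matters for stability).

import math

import torch

# Hypothetical decode-step shapes: one new query token attending to a cached prefix.
batch, num_heads, head_dim, kv_len = 1, 32, 128, 17
query_states = torch.randn(batch, num_heads, 1, head_dim)
key_states = torch.randn(batch, num_heads, kv_len, head_dim)
value_states = torch.randn(batch, num_heads, kv_len, head_dim)
causal_mask = torch.zeros(batch, 1, 1, kv_len)  # additive mask (0 = attend, -inf = masked)

# scores = Q @ K^T / sqrt(head_dim), shifted by the additive causal mask if one is given
attn_weights = torch.matmul(query_states,
                            key_states.transpose(2, 3)) / math.sqrt(head_dim)
if causal_mask is not None:
    attn_weights = attn_weights + causal_mask

# softmax in float32, then cast back to the value dtype before weighting V
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                           dtype=torch.float32).to(value_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)  # (batch, num_heads, 1, head_dim)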