diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 26239eff..df1deff4 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -729,7 +729,7 @@ def llama_attention_forward_4_36(
 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads):
-    attn_weights = torch.matmul(query,
+    attn_weights = torch.matmul(query.to(key.dtype),
                                 key.transpose(2, 3)) / math.sqrt(head_dim)
     attn_weights_size = (bsz, num_heads, q_len, kv_seq_len)
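
Note: the sketch below is not part of the patch; it only illustrates the failure mode the cast avoids, assuming a setup where the cached key tensor is kept in a lower-precision dtype (bfloat16 here as a stand-in) while the query is still float32. torch.matmul rejects mixed operand dtypes, so casting the query to key.dtype before the product keeps native_sdp working.

# Minimal standalone sketch (assumed shapes and dtypes, not taken from the PR).
import math
import torch

bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 2, 4, 4, 8
query = torch.randn(bsz, num_heads, q_len, head_dim, dtype=torch.float32)
key = torch.randn(bsz, num_heads, kv_seq_len, head_dim, dtype=torch.bfloat16)

# Without the cast, torch.matmul raises a RuntimeError because the operands
# have different dtypes (float32 vs. bfloat16).
try:
    torch.matmul(query, key.transpose(2, 3))
except RuntimeError as e:
    print("dtype mismatch:", e)

# With the cast from the patch, the product is computed in key's dtype.
attn_weights = torch.matmul(query.to(key.dtype),
                            key.transpose(2, 3)) / math.sqrt(head_dim)
print(attn_weights.dtype)  # torch.bfloat16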