From 98ea3459e5fc804d8ada9d952cd25106b444de00 Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Fri, 26 Jan 2024 10:59:48 +0800
Subject: [PATCH] LLM : Fix llama draft_model dtype error (#10005)

* fix llama draft_model dtype error

* update
---
 python/llm/src/bigdl/llm/transformers/models/llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 26239eff..df1deff4 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -729,7 +729,7 @@ def llama_attention_forward_4_36(
 
 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads):
-    attn_weights = torch.matmul(query,
+    attn_weights = torch.matmul(query.to(key.dtype),
                                 key.transpose(2, 3)) / math.sqrt(head_dim)
     attn_weights_size = (bsz, num_heads, q_len, kv_seq_len)
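
Note for reviewers (not part of the patch): a minimal standalone sketch of the failure mode this one-line change addresses, assuming the draft model in speculative decoding emits queries in a lower precision than the cached keys. The helper name native_sdp_sketch and the tensor shapes below are illustrative only, not the project's actual native_sdp signature, and the attention mask handling is omitted.

import math
import torch

def native_sdp_sketch(query, key, value, head_dim):
    # Illustrative only: if the draft model produces fp16 queries while the KV
    # cache holds fp32 keys, torch.matmul raises a dtype-mismatch error.
    # Casting the query to the key's dtype, as in the patch, avoids that.
    attn_weights = torch.matmul(query.to(key.dtype),
                                key.transpose(2, 3)) / math.sqrt(head_dim)
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                               dtype=torch.float32)
    return torch.matmul(attn_weights.to(value.dtype), value)

# Hypothetical mixed-dtype inputs that would fail without the cast:
q = torch.randn(1, 32, 4, 128, dtype=torch.float16)   # draft-model query
k = torch.randn(1, 32, 16, 128, dtype=torch.float32)  # cached keys
v = torch.randn(1, 32, 16, 128, dtype=torch.float32)  # cached values
out = native_sdp_sketch(q, k, v, head_dim=128)
print(out.shape, out.dtype)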