From 38c05be1c05b98065391d86b95bed05104980f15 Mon Sep 17 00:00:00 2001
From: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com>
Date: Thu, 4 Jan 2024 15:34:42 +0800
Subject: [PATCH] [LLM] Fix dtype mismatch in Baichuan2-13b (#9834)

---
 python/llm/src/bigdl/llm/transformers/models/baichuan2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index c4b98a1d..5a0e537d 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -287,7 +287,7 @@ def baichuan_attention_forward_13b(
     )
     attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype), value_states)
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
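
Note (not part of the patch): a minimal standalone sketch of the mismatch this one-line change addresses. It assumes hypothetical tensor shapes and variable names; the point is only that attention weights softmax-ed in float32 must be cast back to the value tensor's dtype before the matmul, which is what the patched line does.

    import torch

    # Hypothetical attention shapes for illustration only.
    bsz, num_heads, q_len, head_dim = 1, 2, 4, 8

    # Weights in float32 (e.g. after a float32 softmax), values in half precision.
    attn_weights = torch.randn(bsz, num_heads, q_len, q_len, dtype=torch.float32)
    value_states = torch.randn(bsz, num_heads, q_len, head_dim, dtype=torch.float16)

    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)  # stays float32

    # torch.matmul(attn_weights, value_states) would raise a dtype mismatch error here.
    # Casting the weights to the value dtype, as in the patched line, avoids it:
    attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype), value_states)
    assert attn_output.dtype == torch.float16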