diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index c4b98a1d..5a0e537d 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -287,7 +287,7 @@ def baichuan_attention_forward_13b(
     )
     attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype), value_states)
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
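
The sketch below is an illustrative reproduction (not taken from the patch) of the dtype mismatch this change guards against: if the softmax output ends up in float32 while `value_states` stays in half precision, `torch.matmul` raises a dtype error unless the weights are cast back first. The tensor shapes and dummy values are assumptions chosen only to make the snippet self-contained.

```python
# Minimal sketch of the dtype mismatch fixed by the patched line above.
# Shapes and values are hypothetical; only the cast mirrors the diff.
import torch

bsz, num_heads, q_len, head_dim = 1, 2, 4, 8

# attention weights computed in float32 (e.g. softmax done in full precision)
attn_weights = torch.rand(bsz, num_heads, q_len, q_len, dtype=torch.float32)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

# value states kept in half precision, as in a fp16 model
value_states = torch.rand(bsz, num_heads, q_len, head_dim, dtype=torch.float16)

# torch.matmul(attn_weights, value_states) would fail here because the
# operands have different dtypes; casting the weights first matches the patch.
attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype), value_states)
print(attn_output.dtype)  # torch.float16
```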