diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan2.py b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
index f72adedb..74414ed8 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan2.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
@@ -213,6 +213,7 @@ def baichuan_attention_forward_13b(
     attn_weights = attn_weights / math.sqrt(self.head_dim)
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
+    attn_weights = attn_weights.to(query_states.dtype)
     attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
     if use_quantize_kv and q_len == 1:
         import linear_q4_0
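
Note: the diff itself does not state the motivation, but a plausible reading is that the added cast guards against dtype promotion: if attention_mask is float32 while attn_weights is float16, the addition upcasts attn_weights, so the softmax output would no longer match query_states downstream. The standalone sketch below (not part of the PR; all tensor shapes and dtypes are illustrative assumptions) reproduces that promotion and shows the effect of the added line.

    import torch

    # Illustrative tensors; shapes/dtypes are assumptions, not taken from baichuan2.py.
    query_states = torch.randn(1, 2, 4, 8, dtype=torch.float16)
    attn_weights = torch.randn(1, 2, 4, 4, dtype=torch.float16)
    attention_mask = torch.zeros(1, 1, 4, 4, dtype=torch.float32)

    # float16 + float32 is promoted to float32 by PyTorch's type-promotion rules.
    attn_weights = attn_weights + attention_mask
    print(attn_weights.dtype)  # torch.float32

    # The line added in the diff restores the expected dtype before softmax.
    attn_weights = attn_weights.to(query_states.dtype)
    print(attn_weights.dtype)  # torch.float16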