From 4e97047d70594ac6a25fe4e0eb4a1a78769db516 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Mon, 20 May 2024 11:21:20 +0800
Subject: [PATCH] fix baichuan2 13b fp16 (#11071)

---
 python/llm/src/ipex_llm/transformers/models/baichuan2.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan2.py b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
index f72adedb..74414ed8 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan2.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
@@ -213,6 +213,7 @@ def baichuan_attention_forward_13b(
     attn_weights = attn_weights / math.sqrt(self.head_dim)
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
+        attn_weights = attn_weights.to(query_states.dtype)
     attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
     if use_quantize_kv and q_len == 1:
         import linear_q4_0
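
Note: the following is a minimal standalone sketch, not the ipex_llm code itself; the tensor shapes and the float32 attention mask are illustrative assumptions about the fp16 failure mode the patch addresses. If the mask is float32, PyTorch promotes `attn_weights + attention_mask` to float32, and the subsequent matmul against fp16 value_states then fails with a dtype mismatch; casting back to `query_states.dtype` keeps the whole attention path in fp16.

    # Hypothetical repro of the dtype promotion guarded against by the patch.
    import math
    import torch

    head_dim = 64
    # Illustrative shapes: (batch, heads, seq_len, head_dim), all fp16
    query_states = torch.randn(1, 2, 4, head_dim, dtype=torch.float16)
    key_states = torch.randn(1, 2, 4, head_dim, dtype=torch.float16)
    value_states = torch.randn(1, 2, 4, head_dim, dtype=torch.float16)
    # fp32 mask, as assumed in the failing case
    attention_mask = torch.zeros(1, 1, 4, 4, dtype=torch.float32)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
    attn_weights = attn_weights / math.sqrt(head_dim)
    attn_weights = attn_weights + attention_mask        # silently promoted to float32
    attn_weights = attn_weights.to(query_states.dtype)  # the one-line fix: cast back to float16
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, value_states)  # dtypes now match
    assert attn_output.dtype == torch.float16

Without the cast, the final matmul mixes float32 weights with float16 values and raises a runtime dtype error, which is consistent with the 13B fp16 path this patch fixes.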