From 9220babaab2b77131383efea677fe55221ed583a Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Wed, 13 Nov 2024 11:45:26 +0200
Subject: [PATCH] qwen prefill attn_mask type fp16 (#12394)

---
 python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
index eb001b6d..173271c5 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
@@ -144,7 +144,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
                 (self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.int64)
         else:
             attention_mask = self.create_input_op(
-                (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.int64)
+                (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.float16)
 
         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
 
@@ -522,7 +522,7 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module):
 
         backend_cls = self.backend_cls_prefill
         inputs = (hidden_states.to(torch.float16),
-                  attention_mask.to(torch.int64),
+                  attention_mask.to(torch.float16),
                   position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         inputs += (self.q_bias, self.k_bias, self.v_bias)
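
For context: after this change the prefill path feeds the NPU graph an fp16
attention mask instead of an int64 one, matching the float16 input op declared
in the first hunk. Presumably this lets masked positions carry large negative
values that the softmax zeroes out, which an int64 0/1 mask cannot express.
A minimal, illustrative sketch of such a mask follows; make_prefill_attn_mask
is a hypothetical helper, not part of this patch or of qwen2_mp.py:

    # Illustrative sketch only: builds an fp16 additive causal mask with the
    # shape the prefill graph above expects, (batch_size, 1, seq_len, seq_len).
    import torch

    def make_prefill_attn_mask(batch_size: int, seq_len: int) -> torch.Tensor:
        # Start with the most negative fp16 value everywhere ...
        mask = torch.full((seq_len, seq_len), torch.finfo(torch.float16).min,
                          dtype=torch.float16)
        # ... then keep only the strictly upper triangle, so the diagonal and
        # lower triangle become 0 and each token can attend to itself and to
        # earlier positions (causal masking).
        mask = torch.triu(mask, diagonal=1)
        return mask.expand(batch_size, 1, seq_len, seq_len)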