qwen prefill attn_mask type fp16 (#12394)

Yina Chen 2024-11-13 11:45:26 +02:00 committed by GitHub
parent 1158f91648
commit 9220babaab


@@ -144,7 +144,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
                 (self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.int64)
         else:
             attention_mask = self.create_input_op(
-                (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.int64)
+                (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.float16)
         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
@@ -522,7 +522,7 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module):
             backend_cls = self.backend_cls_prefill
             inputs = (hidden_states.to(torch.float16),
-                      attention_mask.to(torch.int64),
+                      attention_mask.to(torch.float16),
                       position_ids.to(torch.int64))
             inputs += (self.layer_norm_0, self.layer_norm_1)
             inputs += (self.q_bias, self.k_bias, self.v_bias)
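
Note (not part of this commit): with this change the prefill attention mask is expected as float16 additive values rather than int64. The sketch below is only an illustration of one common way to build such a mask; the function name make_prefill_attention_mask and the 0 / large-negative convention are assumptions, not code from this repository. It matches the (batch_size, 1, seq_len, seq_len) shape declared by create_input_op in the first hunk.

import torch

def make_prefill_attention_mask(batch_size: int, seq_len: int) -> torch.Tensor:
    # Illustrative sketch: additive float16 causal mask
    # (0 for visible positions, large negative for masked positions).
    mask = torch.full((seq_len, seq_len),
                      torch.finfo(torch.float16).min,
                      dtype=torch.float16)
    mask = torch.triu(mask, diagonal=1)  # mask out future (upper-triangular) positions
    # Broadcast to (batch_size, 1, seq_len, seq_len), the prefill input shape above.
    return mask.expand(batch_size, 1, seq_len, seq_len)

# Example: a (1, 1, 4, 4) float16 mask that could be fed to the prefill path.
attention_mask = make_prefill_attention_mask(batch_size=1, seq_len=4)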