qwen prefill attn_mask type fp16 (#12394)
commit 9220babaab
parent 1158f91648

1 changed file with 2 additions and 2 deletions
@@ -144,7 +144,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
                 (self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.int64)
         else:
             attention_mask = self.create_input_op(
-                (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.int64)
+                (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.float16)
 
         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
 
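The prefill graph now declares its attention-mask input as fp16 rather than int64, which matches the additive-mask convention: masked positions carry a large negative value that is added to the attention scores before softmax. A minimal NumPy sketch of such a mask, assuming the (batch_size, 1, seq_len, seq_len) shape from the hunk above (make_causal_mask_fp16 is a hypothetical helper, not part of this patch):

import numpy as np

# Hypothetical helper, not from the patch: build an additive causal mask
# in fp16 for the prefill shape (batch_size, 1, seq_len, seq_len).
# Masked (future) positions hold the most negative fp16 value, so adding
# the mask to the attention scores drives them to ~0 after softmax.
def make_causal_mask_fp16(seq_len, batch_size=1):
    mask = np.triu(np.full((seq_len, seq_len), np.finfo(np.float16).min,
                           dtype=np.float16), k=1)
    return np.broadcast_to(mask, (batch_size, 1, seq_len, seq_len)).copy()

print(make_causal_mask_fp16(4).shape)  # (1, 1, 4, 4)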
@@ -522,7 +522,7 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module):
 
         backend_cls = self.backend_cls_prefill
         inputs = (hidden_states.to(torch.float16),
-                  attention_mask.to(torch.int64),
+                  attention_mask.to(torch.float16),
                   position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         inputs += (self.q_bias, self.k_bias, self.v_bias)
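On the PyTorch side, the fused decoder layer's prefill path now casts the mask to torch.float16 before handing it to the backend, keeping its dtype consistent with the fp16 input declared above. A sketch (an assumption, not from the patch) of converting a boolean attend/ignore mask into the additive fp16 form this input expects:

import torch

# Hypothetical conversion, not part of the patch: turn a boolean mask
# (True = attend, False = masked) into the additive fp16 mask that
# attention_mask.to(torch.float16) would then forward unchanged.
def to_additive_fp16(bool_mask: torch.Tensor) -> torch.Tensor:
    zeros = torch.zeros_like(bool_mask, dtype=torch.float16)
    neg = torch.full_like(zeros, torch.finfo(torch.float16).min)
    return torch.where(bool_mask, zeros, neg)

m = torch.ones(1, 1, 4, 4, dtype=torch.bool).tril()  # causal pattern
print(to_additive_fp16(m)[0, 0])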