qwen prefill attn_mask type fp16 (#12394)
This commit is contained in:

parent 1158f91648
commit 9220babaab

1 changed file with 2 additions and 2 deletions
@@ -144,7 +144,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
                 (self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.int64)
         else:
             attention_mask = self.create_input_op(
-                (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.int64)
+                (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.float16)
 
         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
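For context, a minimal sketch (not part of this commit) of an additive causal mask matching the (batch_size, 1, seq_len, seq_len) float16 input declared above; the helper name and the choice of the most negative fp16 value for masked positions are assumptions, not taken from this repository.

    import numpy as np

    def build_causal_mask_fp16(batch_size: int, seq_len: int) -> np.ndarray:
        # Hypothetical helper: additive causal mask in the shape and dtype the
        # prefill input op now declares -- 0 where attention is allowed, the most
        # negative float16 value where it is not.
        mask = np.full((seq_len, seq_len), np.finfo(np.float16).min, dtype=np.float16)
        mask = np.triu(mask, k=1)  # keep only future positions masked
        return np.broadcast_to(mask, (batch_size, 1, seq_len, seq_len))
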
@@ -522,7 +522,7 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module):
 
         backend_cls = self.backend_cls_prefill
         inputs = (hidden_states.to(torch.float16),
-                  attention_mask.to(torch.int64),
+                  attention_mask.to(torch.float16),
                   position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         inputs += (self.q_bias, self.k_bias, self.v_bias)
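As a usage sketch with assumed, illustrative shapes (not from the repository), the caller now casts the mask to float16 so it matches the graph input declared in the first hunk:

    import torch

    batch_size, seq_len, hidden_size = 1, 16, 64  # illustrative sizes only
    hidden_states = torch.zeros(batch_size, seq_len, hidden_size)
    attention_mask = torch.zeros(batch_size, 1, seq_len, seq_len)
    position_ids = torch.arange(seq_len).unsqueeze(0)

    inputs = (hidden_states.to(torch.float16),
              attention_mask.to(torch.float16),  # was torch.int64 before this commit
              position_ids.to(torch.int64))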