[NPU] change attention_mask to fp16 (#12400)

parent 7e50ff113c
commit d4d949443f

5 changed files with 25 additions and 28 deletions
@@ -122,7 +122,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
-                                                  dtype=np.int64)
+                                                  dtype=np.float16)
         else:
             attention_mask = None

@@ -287,7 +287,6 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         else:
             attn_weight = self.matmul(query_states, key_states, False, True) / (
                 math.sqrt(self.head_dim))
-            attention_mask = self.convert_to_fp16(attention_mask)
             attn_weight = self.eltwise_add(attn_weight, attention_mask)
             attn_weight = self.convert_to_fp32(attn_weight)
             attn_weight = self.softmax(attn_weight, -1)

@@ -451,7 +450,7 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):

         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             position_ids.to(torch.int64),
         )

@@ -697,9 +696,9 @@ def run_decode(

                 pad_mask = (0, pad_len)
                 padded_causal_mask = F.pad(
-                    attention_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min
+                    attention_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
                 )
-                padded_causal_mask[:, :, :, -1] = 0
+                padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
                 layer_outputs = multi_decoder(
                     hidden_states,

@@ -950,9 +949,9 @@ class PrefillRunner:
         hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
         position_ids = F.pad(position_ids, (0, pad_len), value=0)
         attention_mask = F.pad(
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             (0, pad_len, 0, pad_len),
-            value=torch.iinfo(torch.int64).min,
+            value=torch.finfo(torch.float16).min,
         )

         args = (hidden_states, position_ids, attention_mask, past_key_value)

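For reference, a minimal plain-PyTorch sketch (toy sizes assumed, not the NPU pipeline itself) of the decode-time padding pattern the run_decode and PrefillRunner hunks above switch to: the additive mask stays in float16, padded slots are filled with torch.finfo(torch.float16).min rather than the int64 minimum, and the slot for the current token is cleared with a float 0.0.

import torch
import torch.nn.functional as F

# Toy sizes, assumed purely for illustration.
max_seq_len = 8
past_seen_tokens = 3
pad_len = max_seq_len + 1 - (past_seen_tokens + 1)

# Additive mask for one decode step: 0.0 where attention is allowed.
# Shape: (batch, 1, q_len, kv_len) with kv_len = past tokens + the new token.
attention_mask = torch.zeros(1, 1, 1, past_seen_tokens + 1, dtype=torch.float16)

# Pad the key/value dimension up to max_seq_len + 1; padded slots are masked
# with the float16 minimum instead of torch.iinfo(torch.int64).min.
padded_causal_mask = F.pad(
    attention_mask.to(torch.float16), (0, pad_len),
    value=torch.finfo(torch.float16).min,
)
# The last slot holds the token being generated, so it must stay visible.
padded_causal_mask[:, :, :, -1] = 0.0
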
@@ -113,14 +113,14 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
-                                                  dtype=np.int64)
+                                                  dtype=np.float16)
         else:
             if use_prefill_sdp:
                 attention_mask = None
             else:
                 attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len,
                                                        self.seq_len),
-                                                      dtype=np.int64)
+                                                      dtype=np.float16)
         if self.cached_cos is None:
             if mode == "prefill":
                 position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)

@@ -364,7 +364,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):

         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
         )

         if self.cached_cos is None:

@@ -494,7 +494,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
                       position_ids.to(torch.int64))
         else:
             inputs = (hidden_states.to(torch.float16),
-                      attention_mask.to(torch.int64),
+                      attention_mask.to(torch.float16),
                       position_ids.to(torch.int64))
         if self.cached_cos is None:
             inputs += (cos.to(torch.float32), sin.to(torch.float32),)

@@ -625,7 +625,7 @@ def run_decode(
                 past_key_values = input_queue.get()
             else:
                 past_seen_tokens = past_key_values.get_seq_length()
-                attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
+                attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.float16)
                 cache_position = torch.arange(
                     past_seen_tokens, past_seen_tokens + 1, device=hidden_states.device
                 )

@@ -938,9 +938,9 @@ class PrefillRunner:
         hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
         position_ids = F.pad(position_ids, (0, pad_len), value=0)
         attention_mask = F.pad(
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             (0, pad_len, 0, pad_len),
-            value=torch.iinfo(torch.int64).min,
+            value=torch.finfo(torch.float16).min,
         )

         args = (hidden_states, position_ids, attention_mask, past_key_value,

@@ -125,10 +125,10 @@ class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory):
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
-                                                  dtype=np.int64)
+                                                  dtype=np.float16)
         else:
             attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
-                                                  dtype=np.int64)
+                                                  dtype=np.float16)

         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)

@@ -357,7 +357,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):

         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             position_ids.to(torch.int64),
         )

@@ -475,7 +475,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):

         backend_cls = self.backend_cls_prefill
         inputs = (hidden_states.to(torch.float16),
-                  attention_mask.to(torch.int64),
+                  attention_mask.to(torch.float16),
                   position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(

@@ -599,7 +599,7 @@ def run_decode(
                 past_key_values = input_queue.get()
             else:
                 past_seen_tokens = past_key_values.get_seq_length()
-                attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
+                attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.float16)
                 cache_position = torch.arange(
                     past_seen_tokens, past_seen_tokens + 1, device=hidden_states.device
                 )

@@ -878,9 +878,9 @@ class PrefillRunner:
         hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
         position_ids = F.pad(position_ids, (0, pad_len), value=0)
         attention_mask = F.pad(
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             (0, pad_len, 0, pad_len),
-            value=torch.iinfo(torch.int64).min,
+            value=torch.finfo(torch.float16).min,
         )

         args = (hidden_states, position_ids, attention_mask, past_key_value)

@@ -247,8 +247,6 @@ class LLMBaseNNFactory(NNFactory):
             attn_weight = self.matmul(query_states, key_states, False, True) / (
                 math.sqrt(head_dim)
             )
-            if mode != "prefill":
-                attention_mask = self.convert_to_fp16(attention_mask)
             attn_weight = self.eltwise_add(attn_weight, attention_mask)
             attn_weight = self.convert_to_fp32(attn_weight)
             attn_weight = self.softmax(attn_weight, -1)

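The hunk above can drop the conditional convert_to_fp16(attention_mask) call because the mask now enters the graph as float16 in both decode and prefill mode. A rough plain-PyTorch stand-in for that attention path (toy shapes assumed; the real code builds the equivalent NNFactory graph ops) looks like this:

import math
import torch

# Toy shapes, assumed purely for illustration.
batch, heads, q_len, kv_len, head_dim = 1, 2, 1, 5, 16

# In the NPU graph the activations are float16 as well; float32 is used here
# only so the sketch runs on any CPU build of PyTorch.
query_states = torch.randn(batch, heads, q_len, head_dim)
key_states = torch.randn(batch, heads, kv_len, head_dim)

# The additive mask already arrives in float16: 0 where attention is allowed,
# float16-min where it is masked.
attention_mask = torch.zeros(batch, 1, q_len, kv_len, dtype=torch.float16)
attention_mask[..., :2] = torch.finfo(torch.float16).min

# matmul + scale, add the mask directly (no convert_to_fp16 step), then take
# the softmax in float32 as the original code does via convert_to_fp32.
attn_weight = query_states @ key_states.transpose(-1, -2) / math.sqrt(head_dim)
attn_weight = attn_weight + attention_mask
attn_weight = torch.softmax(attn_weight.float(), dim=-1)
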
@@ -141,7 +141,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op(
-                (self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.int64)
+                (self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.float16)
         else:
             attention_mask = self.create_input_op(
                 (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.float16)

@@ -403,7 +403,7 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):

         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             position_ids.to(torch.int64),
         )

@@ -649,7 +649,7 @@ def run_decode(
                 past_key_values = input_queue.get()
             else:
                 past_seen_tokens = past_key_values.get_seq_length()
-                attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
+                attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.float16)
                 position_ids = torch.arange(
                     past_seen_tokens,
                     1 + past_seen_tokens,

@@ -672,9 +672,9 @@ def run_decode(
                 causal_mask[:, :, :, -1] = torch.finfo(torch.float16).min
                 pad_mask = (0, pad_len)
                 padded_causal_mask = F.pad(
-                    causal_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min
+                    causal_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
                 )
-                padded_causal_mask[:, :, :, -1] = 0
+                padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
                 layer_outputs = multi_decoder(
                     hidden_states,

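As a quick numeric sanity check (ordinary PyTorch, nothing NPU-specific) of why the fill value changes together with the mask dtype throughout this commit: the int64 minimum lies far outside the float16 range and overflows to -inf when cast, whereas torch.finfo(torch.float16).min remains a finite, representable mask value.

import torch

int64_min = torch.tensor(torch.iinfo(torch.int64).min)   # -9223372036854775808
print(int64_min.to(torch.float16))                        # tensor(-inf, dtype=torch.float16)
print(torch.finfo(torch.float16).min)                     # -65504.0

Keeping the mask finite also sidesteps the NaN that softmax produces when an entire row of scores ends up at -inf, which fully padded rows could otherwise hit.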