[NPU] change attention_mask to fp16 (#12400)

binbin Deng 2024-11-14 17:20:29 +08:00 committed by GitHub
parent 7e50ff113c
commit d4d949443f
5 changed files with 25 additions and 28 deletions
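
Across all five files the additive attention mask fed to the NPU decoder graphs switches from int64 to float16: the graph input op is now declared with dtype=np.float16, the host-side tensors are cast with .to(torch.float16), and masked or padded positions use torch.finfo(torch.float16).min instead of torch.iinfo(torch.int64).min. A minimal sketch of the prefill-side padding pattern follows; the shapes and the triu-based mask construction are illustrative assumptions, and only the F.pad call mirrors the PrefillRunner hunks below.

    import torch
    import torch.nn.functional as F

    # Hypothetical sizes; the real values come from the runner configuration.
    seq_len, pad_len = 4, 2

    # Additive causal mask in float16: 0.0 where attention is allowed,
    # torch.finfo(torch.float16).min where it is masked (illustrative construction).
    attention_mask = torch.triu(
        torch.full((1, 1, seq_len, seq_len), torch.finfo(torch.float16).min,
                   dtype=torch.float16),
        diagonal=1,
    )

    # Pad both the query and key axes out to the compiled graph size, masking
    # the padded positions (same pattern as the PrefillRunner hunks below).
    attention_mask = F.pad(
        attention_mask.to(torch.float16),
        (0, pad_len, 0, pad_len),
        value=torch.finfo(torch.float16).min,
    )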

View file

@@ -122,7 +122,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
-                                                  dtype=np.int64)
+                                                  dtype=np.float16)
         else:
             attention_mask = None
@@ -287,7 +287,6 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         else:
             attn_weight = self.matmul(query_states, key_states, False, True) / (
                 math.sqrt(self.head_dim))
-            attention_mask = self.convert_to_fp16(attention_mask)
             attn_weight = self.eltwise_add(attn_weight, attention_mask)
             attn_weight = self.convert_to_fp32(attn_weight)
             attn_weight = self.softmax(attn_weight, -1)
@@ -451,7 +450,7 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             position_ids.to(torch.int64),
         )
@@ -697,9 +696,9 @@ def run_decode(
                 pad_mask = (0, pad_len)
                 padded_causal_mask = F.pad(
-                    attention_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min
+                    attention_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
                 )
-                padded_causal_mask[:, :, :, -1] = 0
+                padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
                 layer_outputs = multi_decoder(
                     hidden_states,
@@ -950,9 +949,9 @@ class PrefillRunner:
         hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
         position_ids = F.pad(position_ids, (0, pad_len), value=0)
         attention_mask = F.pad(
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             (0, pad_len, 0, pad_len),
-            value=torch.iinfo(torch.int64).min,
+            value=torch.finfo(torch.float16).min,
         )
         args = (hidden_states, position_ids, attention_mask, past_key_value)

View file

@@ -113,14 +113,14 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
-                                                  dtype=np.int64)
+                                                  dtype=np.float16)
         else:
             if use_prefill_sdp:
                 attention_mask = None
             else:
                 attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len,
                                                        self.seq_len),
-                                                      dtype=np.int64)
+                                                      dtype=np.float16)
         if self.cached_cos is None:
             if mode == "prefill":
                 position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
@@ -364,7 +364,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
         )
         if self.cached_cos is None:
@@ -494,7 +494,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
                       position_ids.to(torch.int64))
         else:
             inputs = (hidden_states.to(torch.float16),
-                      attention_mask.to(torch.int64),
+                      attention_mask.to(torch.float16),
                       position_ids.to(torch.int64))
         if self.cached_cos is None:
             inputs += (cos.to(torch.float32), sin.to(torch.float32),)
@@ -625,7 +625,7 @@ def run_decode(
             past_key_values = input_queue.get()
         else:
             past_seen_tokens = past_key_values.get_seq_length()
-            attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
+            attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.float16)
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + 1, device=hidden_states.device
             )
@@ -938,9 +938,9 @@ class PrefillRunner:
         hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
         position_ids = F.pad(position_ids, (0, pad_len), value=0)
         attention_mask = F.pad(
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             (0, pad_len, 0, pad_len),
-            value=torch.iinfo(torch.int64).min,
+            value=torch.finfo(torch.float16).min,
         )
         args = (hidden_states, position_ids, attention_mask, past_key_value,

View file

@@ -125,10 +125,10 @@ class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory):
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
-                                                  dtype=np.int64)
+                                                  dtype=np.float16)
         else:
             attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
-                                                  dtype=np.int64)
+                                                  dtype=np.float16)
         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
@@ -357,7 +357,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             position_ids.to(torch.int64),
         )
@@ -475,7 +475,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
             backend_cls = self.backend_cls_prefill
             inputs = (hidden_states.to(torch.float16),
-                      attention_mask.to(torch.int64),
+                      attention_mask.to(torch.float16),
                       position_ids.to(torch.int64))
             inputs += (self.layer_norm_0, self.layer_norm_1)
             hidden_states, past_key, past_value = run_model(
@@ -599,7 +599,7 @@ def run_decode(
             past_key_values = input_queue.get()
         else:
             past_seen_tokens = past_key_values.get_seq_length()
-            attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
+            attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.float16)
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + 1, device=hidden_states.device
             )
@@ -878,9 +878,9 @@ class PrefillRunner:
         hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
         position_ids = F.pad(position_ids, (0, pad_len), value=0)
         attention_mask = F.pad(
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             (0, pad_len, 0, pad_len),
-            value=torch.iinfo(torch.int64).min,
+            value=torch.finfo(torch.float16).min,
         )
         args = (hidden_states, position_ids, attention_mask, past_key_value)

View file

@@ -247,8 +247,6 @@ class LLMBaseNNFactory(NNFactory):
         attn_weight = self.matmul(query_states, key_states, False, True) / (
             math.sqrt(head_dim)
         )
-        if mode != "prefill":
-            attention_mask = self.convert_to_fp16(attention_mask)
         attn_weight = self.eltwise_add(attn_weight, attention_mask)
         attn_weight = self.convert_to_fp32(attn_weight)
         attn_weight = self.softmax(attn_weight, -1)
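
With the mask input now declared as float16 in the factories above, the convert_to_fp16 call before eltwise_add is redundant, which is why it is removed here and in the Baichuan hunk earlier. A rough PyTorch sketch of the additive-mask arithmetic (not the NNFactory graph API; the shapes and masked positions are made up for illustration):

    import math
    import torch

    head_dim, q_len, kv_len = 64, 1, 8   # illustrative decode-step shapes

    # Stand-in for the scaled q @ k^T scores the graph computes, already float16.
    attn_weight = (torch.randn(1, 1, q_len, kv_len) / math.sqrt(head_dim)).to(torch.float16)

    # float16 additive mask: 0.0 keeps a position, finfo(float16).min removes it.
    attention_mask = torch.zeros(1, 1, q_len, kv_len, dtype=torch.float16)
    attention_mask[..., :3] = torch.finfo(torch.float16).min   # mask the first 3 slots

    attn_weight = attn_weight + attention_mask      # both operands already float16
    attn_weight = torch.softmax(attn_weight.to(torch.float32), dim=-1)
    print(attn_weight[0, 0, 0, :3])                 # ~0 probability at the masked slots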

View file

@@ -141,7 +141,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op(
-                (self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.int64)
+                (self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.float16)
         else:
             attention_mask = self.create_input_op(
                 (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.float16)
@@ -403,7 +403,7 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             position_ids.to(torch.int64),
         )
@@ -649,7 +649,7 @@ def run_decode(
             past_key_values = input_queue.get()
         else:
             past_seen_tokens = past_key_values.get_seq_length()
-            attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
+            attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.float16)
             position_ids = torch.arange(
                 past_seen_tokens,
                 1 + past_seen_tokens,
@@ -672,9 +672,9 @@ def run_decode(
                 causal_mask[:, :, :, -1] = torch.finfo(torch.float16).min
                 pad_mask = (0, pad_len)
                 padded_causal_mask = F.pad(
-                    causal_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min
+                    causal_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
                 )
-                padded_causal_mask[:, :, :, -1] = 0
+                padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
                 layer_outputs = multi_decoder(
                     hidden_states,
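
The decode path follows the same convention: run_decode now builds the running mask with float16 ones, and the causal mask is widened to the compiled max_seq_len + 1 slots with finfo(float16).min before the final slot (presumably the token currently being decoded) is reopened with 0.0. A rough sketch of that padding step with made-up sizes:

    import torch
    import torch.nn.functional as F

    max_seq_len, kv_len = 16, 5            # illustrative sizes
    pad_len = max_seq_len + 1 - kv_len

    # 4-D additive mask over the tokens seen so far (0.0 = attend).
    causal_mask = torch.zeros(1, 1, 1, kv_len, dtype=torch.float16)

    # Same steps as the run_decode hunk above: mask the trailing slot, pad the
    # KV axis out to max_seq_len + 1 with finfo(float16).min, then reopen the
    # last slot.
    causal_mask[:, :, :, -1] = torch.finfo(torch.float16).min
    padded_causal_mask = F.pad(
        causal_mask.to(torch.float16), (0, pad_len), value=torch.finfo(torch.float16).min
    )
    padded_causal_mask[:, :, :, -1] = 0.0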