diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
index c04bad74..e2dd913c 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
@@ -122,7 +122,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
 
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
-                                                  dtype=np.int64)
+                                                  dtype=np.float16)
         else:
             attention_mask = None
@@ -287,7 +287,6 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         else:
             attn_weight = self.matmul(query_states, key_states, False, True) / (
                 math.sqrt(self.head_dim))
-            attention_mask = self.convert_to_fp16(attention_mask)
             attn_weight = self.eltwise_add(attn_weight, attention_mask)
             attn_weight = self.convert_to_fp32(attn_weight)
             attn_weight = self.softmax(attn_weight, -1)
@@ -451,7 +450,7 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
 
         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             position_ids.to(torch.int64),
         )
 
@@ -697,9 +696,9 @@ def run_decode(
                 pad_mask = (0, pad_len)
                 padded_causal_mask = F.pad(
-                    attention_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min
+                    attention_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
                 )
-                padded_causal_mask[:, :, :, -1] = 0
+                padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
                 layer_outputs = multi_decoder(
                     hidden_states,
@@ -950,9 +949,9 @@ class PrefillRunner:
         hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
         position_ids = F.pad(position_ids, (0, pad_len), value=0)
         attention_mask = F.pad(
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             (0, pad_len, 0, pad_len),
-            value=torch.iinfo(torch.int64).min,
+            value=torch.finfo(torch.float16).min,
         )
 
         args = (hidden_states, position_ids, attention_mask, past_key_value)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index fda6530b..1f3ac302 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -113,14 +113,14 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
 
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
-                                                  dtype=np.int64)
+                                                  dtype=np.float16)
         else:
             if use_prefill_sdp:
                 attention_mask = None
             else:
                 attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
-                                                      dtype=np.int64)
+                                                      dtype=np.float16)
         if self.cached_cos is None:
             if mode == "prefill":
                 position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
@@ -364,7 +364,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 
         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
         )
 
         if self.cached_cos is None:
@@ -494,7 +494,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
                       position_ids.to(torch.int64))
         else:
             inputs = (hidden_states.to(torch.float16),
-                      attention_mask.to(torch.int64),
+                      attention_mask.to(torch.float16),
                       position_ids.to(torch.int64))
         if self.cached_cos is None:
             inputs += (cos.to(torch.float32), sin.to(torch.float32),)
@@ -625,7 +625,7 @@ def run_decode(
                 past_key_values = input_queue.get()
             else:
                 past_seen_tokens = past_key_values.get_seq_length()
-                attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
+                attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.float16)
                 cache_position = torch.arange(
                     past_seen_tokens, past_seen_tokens + 1, device=hidden_states.device
                 )
@@ -938,9 +938,9 @@ class PrefillRunner:
         hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
         position_ids = F.pad(position_ids, (0, pad_len), value=0)
         attention_mask = F.pad(
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             (0, pad_len, 0, pad_len),
-            value=torch.iinfo(torch.int64).min,
+            value=torch.finfo(torch.float16).min,
         )
 
         args = (hidden_states, position_ids, attention_mask, past_key_value,
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
index e9fbfce1..bc0df951 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
@@ -125,10 +125,10 @@ class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory):
 
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
-                                                  dtype=np.int64)
+                                                  dtype=np.float16)
         else:
             attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
-                                                  dtype=np.int64)
+                                                  dtype=np.float16)
 
         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
@@ -357,7 +357,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 
         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             position_ids.to(torch.int64),
         )
 
@@ -475,7 +475,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         backend_cls = self.backend_cls_prefill
         inputs = (hidden_states.to(torch.float16),
-                  attention_mask.to(torch.int64),
+                  attention_mask.to(torch.float16),
                   position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
@@ -599,7 +599,7 @@ def run_decode(
                 past_key_values = input_queue.get()
             else:
                 past_seen_tokens = past_key_values.get_seq_length()
-                attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
+                attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.float16)
                 cache_position = torch.arange(
                     past_seen_tokens, past_seen_tokens + 1, device=hidden_states.device
                 )
@@ -878,9 +878,9 @@ class PrefillRunner:
         hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
         position_ids = F.pad(position_ids, (0, pad_len), value=0)
         attention_mask = F.pad(
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             (0, pad_len, 0, pad_len),
-            value=torch.iinfo(torch.int64).min,
+            value=torch.finfo(torch.float16).min,
         )
 
         args = (hidden_states, position_ids, attention_mask, past_key_value)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index 0b008729..ccf6e242 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -247,8 +247,6 @@ class LLMBaseNNFactory(NNFactory):
             attn_weight = self.matmul(query_states, key_states, False, True) / (
                 math.sqrt(head_dim)
             )
-            if mode != "prefill":
-                attention_mask = self.convert_to_fp16(attention_mask)
             attn_weight = self.eltwise_add(attn_weight, attention_mask)
             attn_weight = self.convert_to_fp32(attn_weight)
             attn_weight = self.softmax(attn_weight, -1)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
index 58c2632f..015efe10 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
@@ -141,7 +141,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op(
-                (self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.int64)
+                (self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.float16)
         else:
             attention_mask = self.create_input_op(
                 (self.batch_size, 1, self.seq_len, self.seq_len), dtype=np.float16)
@@ -403,7 +403,7 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
 
         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask.to(torch.int64),
+            attention_mask.to(torch.float16),
             position_ids.to(torch.int64),
         )
 
@@ -649,7 +649,7 @@ def run_decode(
                 past_key_values = input_queue.get()
             else:
                 past_seen_tokens = past_key_values.get_seq_length()
-                attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
+                attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.float16)
                 position_ids = torch.arange(
                     past_seen_tokens,
                     1 + past_seen_tokens,
@@ -672,9 +672,9 @@
                 causal_mask[:, :, :, -1] = torch.finfo(torch.float16).min
                 pad_mask = (0, pad_len)
                 padded_causal_mask = F.pad(
-                    causal_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min
+                    causal_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
                 )
-                padded_causal_mask[:, :, :, -1] = 0
+                padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
                 layer_outputs = multi_decoder(
                     hidden_states,
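The hunks above repeatedly apply one pattern: build the attention mask directly in float16, pad the unused KV-cache slots with torch.finfo(torch.float16).min instead of torch.iinfo(torch.int64).min, and keep the slot of the current token at 0.0. Below is a minimal, self-contained sketch of that pattern; the concrete sizes (max_seq_len, past_seen_tokens) are placeholder values for illustration, not taken from the diff.

import torch
import torch.nn.functional as F

max_seq_len = 1023      # placeholder KV-cache capacity
past_seen_tokens = 5    # placeholder number of already cached tokens

# Decode-time mask: one query row that may attend to the cached tokens plus itself.
attention_mask = torch.zeros([1, 1, 1, past_seen_tokens + 1], dtype=torch.float16)

# Pad out to the fixed NPU input width; padded slots get the most negative fp16
# value so softmax drives their attention weights to zero.
pad_len = max_seq_len + 1 - attention_mask.size(-1)
padded_causal_mask = F.pad(
    attention_mask.to(torch.float16), (0, pad_len),
    value=torch.finfo(torch.float16).min,
)
# The last slot holds the current token and must stay visible.
padded_causal_mask[:, :, :, -1] = 0.0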