diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
index fdfdc528..453aee0e 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
@@ -396,7 +396,7 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
         inputs = (
             hidden_states.to(torch.float16),
             attention_mask,
-            position_ids,
+            position_ids.to(torch.float16),
         )
 
         for i in range(self.intra_stages):
@@ -502,7 +502,7 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
         seq_len = hidden_states.shape[1]
 
         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
+        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index b375c305..6039d94d 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -106,31 +106,13 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
 
         # Self Attention
         if mode == "decode":
-            attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1))
+            attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
+                                                  dtype=np.int64)
         else:
-            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len))
+            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
+                                                  dtype=np.int64)
 
-        position_ids = self.create_input_op((self.batch_size, self.seq_len))
-        past_keys = []
-        past_values = []
-        if mode == "decode":
-            for i in range(num_layers):
-                past_key = self.create_cache_op(
-                    (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
-                )
-                if transpose_value:
-                    past_value = self.create_cache_op(
-                        (self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
-                    )
-                else:
-                    past_value = self.create_cache_op(
-                        (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
-                    )
-                past_keys.append(past_key)
-                past_values.append(past_value)
-        else:
-            past_keys = [None] * num_layers
-            past_values = [None] * num_layers
+        position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
 
         if input_layernorm_weights is None:
             input_layernorm_weights = []
@@ -156,6 +138,27 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
             input_layernorm_weights = [self.constant(w) for w in input_layernorm_weights]
             post_attn_layernorm_weights = [self.constant(w) for w in post_attn_layernorm_weights]
 
+        past_keys = []
+        past_values = []
+        if mode == "decode":
+            for i in range(num_layers):
+                past_key = self.create_cache_op(
+                    (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
+                )
+                if transpose_value:
+                    past_value = self.create_cache_op(
+                        (self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
+                    )
+                else:
+                    past_value = self.create_cache_op(
+                        (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
+                    )
+                past_keys.append(past_key)
+                past_values.append(past_value)
+        else:
+            past_keys = [None] * num_layers
+            past_values = [None] * num_layers
+
         hidden_states = input
 
         curr_key_values = []
@@ -310,8 +313,8 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 
         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask,
-            position_ids,
+            attention_mask.to(torch.int64),
+            position_ids.to(torch.int64),
         )
 
         for i in range(self.intra_stages):
@@ -419,7 +422,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         seq_len = hidden_states.shape[1]
 
        backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
+        inputs = (hidden_states.to(torch.float16),
+                  attention_mask.to(torch.int64),
+                  position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
@@ -544,9 +549,9 @@ def run_decode(
 
                 pad_mask = (0, pad_len)
                 padded_causal_mask = F.pad(
-                    causal_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
+                    causal_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min
                 )
-                padded_causal_mask[:, :, :, -1] = 0.0
+                padded_causal_mask[:, :, :, -1] = 0
                 dist.recv(hidden_states, src=rank - 1)
                 layer_outputs = multi_decoder(
                     hidden_states,
@@ -796,9 +801,9 @@ class PrefillRunner:
         hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
         position_ids = F.pad(position_ids, (0, pad_len), value=0)
         attention_mask = F.pad(
-            attention_mask.to(torch.float16),
+            attention_mask.to(torch.int64),
             (0, pad_len, 0, pad_len),
-            value=torch.finfo(torch.float16).min,
+            value=torch.iinfo(torch.int64).min,
         )
 
         args = (hidden_states, position_ids, attention_mask, past_key_value, cache_position)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
index 55299608..15f545fa 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
@@ -335,7 +335,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         inputs = (
             hidden_states.to(torch.float16),
             attention_mask,
-            position_ids,
+            position_ids.to(torch.float16),
         )
 
         for i in range(self.intra_stages):
@@ -445,7 +445,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         seq_len = hidden_states.shape[1]
 
         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
+        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index d918e4a7..48b60baf 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -54,7 +54,8 @@ def run_model(
 
     # Reshape input
     input_dtype = x[0].dtype
-    x_np = [set_contiguous(elem).to(torch.float16).numpy() for elem in x]
+    x_np = [set_contiguous(elem).numpy() if elem.dtype == torch.int64 else
+            set_contiguous(elem).to(torch.float16).numpy() for elem in x]
     op_args = []
     op_args_flatten = []
     for w in weights:
@@ -279,6 +280,7 @@ class LLMBaseNNFactory(NNFactory):
         attn_weight = self.matmul(query_states, key_states, False, True) / (
             math.sqrt(head_dim)
         )
+        attention_mask = self.convert_to_fp16(attention_mask)
         attn_weight = self.eltwise_add(attn_weight, attention_mask)
         attn_weight = self.convert_to_fp32(attn_weight)
         attn_weight = self.softmax(attn_weight, -1)
@@ -476,13 +478,13 @@ class LLMBaseNNFactory(NNFactory):
         self.cache_parameter_ops.append(op)
         return op
 
-    def create_input_op(self, shape):
+    def create_input_op(self, shape, dtype=np.float16):
         invalidInputError(len(self.cache_parameter_ops) == 0,
                           "create_input_op should be called before any create_cache_op")
         invalidInputError(len(self.linear_ops) == 0,
                           "create_input_op should be called before any linear op")
 
-        op = super().parameter(shape)
+        op = super().parameter(shape, dtype)
         self.input_ops.append(op)
         return op
 
@@ -563,7 +565,8 @@ class LLMBaseNNFactory(NNFactory):
 
     @staticmethod
     def run_decoders(inputs, decoders, models_ptr=None):
-        x_np = [elem.to(torch.float16).numpy() for elem in inputs]
+        x_np = [elem.numpy() if elem.dtype == torch.int64 else
+                elem.to(torch.float16).numpy() for elem in inputs]
 
         num_decoders = len(decoders)
         num_inputs = len(x_np)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
index b4ad6770..e4092576 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
@@ -413,7 +413,7 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
         inputs = (
             hidden_states.to(torch.float16),
             attention_mask,
-            position_ids,
+            position_ids.to(torch.float16),
        )
 
         for i in range(self.intra_stages):
@@ -530,7 +530,7 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module):
         seq_len = hidden_states.shape[1]
 
         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids)
+        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         inputs += (self.q_bias, self.k_bias, self.v_bias)
         hidden_states, past_key, past_value = run_model(
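
Editorial note (not part of the patch): a minimal standalone sketch of the dtype-aware host-side conversion this change introduces in `run_model()` / `run_decoders()` in mp_models_base.py. After the patch, int64 inputs (the llama `attention_mask` and `position_ids`) are handed to the NPU backend unchanged, while all other inputs are still cast to float16. The helper name `_to_npu_numpy` and the example shapes below are illustrative assumptions, not code from the repository.

```python
import torch


def _to_npu_numpy(tensors):
    """Hypothetical helper mirroring the conversion in run_model()/run_decoders():
    int64 tensors keep their dtype, everything else is cast to float16."""
    out = []
    for t in tensors:
        t = t.contiguous()  # same effect as set_contiguous() in mp_models_base.py
        if t.dtype == torch.int64:
            out.append(t.numpy())  # e.g. llama attention_mask / position_ids after this patch
        else:
            out.append(t.to(torch.float16).numpy())  # e.g. hidden_states
    return out


# Illustrative shapes only (not taken from the patch)
hidden_states = torch.randn(1, 1, 4096)
attention_mask = torch.zeros(1, 1, 1, 1024, dtype=torch.int64)
position_ids = torch.arange(1).unsqueeze(0)  # int64 by default
print([a.dtype for a in _to_npu_numpy([hidden_states, attention_mask, position_ids])])
# [dtype('float16'), dtype('int64'), dtype('int64')]
```

Keeping the mask and position ids in int64 on the host side matches the new `np.int64` input ops created in `LowBitLlamaMultiDecoderlayer`; inside the NPU graph the mask is converted back to fp16 (`convert_to_fp16`) immediately before the `eltwise_add` with the attention weights, as the mp_models_base.py hunk above shows.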