From d872639395bd2dc2bb885522724cf33a5424061e Mon Sep 17 00:00:00 2001 From: Yina Chen <33650826+cyita@users.noreply.github.com> Date: Tue, 5 Nov 2024 09:51:31 +0200 Subject: [PATCH] [NPU] Llama3, Qwen2 1.5b, MiniCPM 1/2B groupwise support (#12327) * support minicpm 1b & qwen 1.5b gw * support minicpm 1b * support minicpm 2b * fix style & error * fix style & update * remove print --- .../LLM/Pipeline-Models/minicpm.py | 2 + .../transformers/npu_models/common.py | 10 +- .../transformers/npu_models/convert_mp.py | 26 +++- .../transformers/npu_models/llama_mp.py | 6 +- .../transformers/npu_models/minicpm_mp.py | 133 +++++++++++++++--- .../transformers/npu_models/mp_models_base.py | 1 - .../npu_pipeline_model/convert_pipeline.py | 3 + .../npu_pipeline_model/minicpm.py | 105 ++++++++++---- .../transformers/npu_pipeline_model/qwen.py | 21 ++- 9 files changed, 239 insertions(+), 68 deletions(-) diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/minicpm.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/minicpm.py index a84f78a7..d9bcae4b 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/minicpm.py +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/minicpm.py @@ -47,6 +47,7 @@ if __name__ == "__main__": parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict") parser.add_argument("--max-context-len", type=int, default=1024) parser.add_argument("--max-prompt-len", type=int, default=512) + parser.add_argument("--quantization_group_size", type=int, default=0) parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) parser.add_argument("--disable-streaming", action="store_true", default=False) @@ -61,6 +62,7 @@ if __name__ == "__main__": max_prompt_len=args.max_prompt_len, torch_dtype=torch.float16, attn_implementation="eager", + quantization_group_size=args.quantization_group_size, transpose_value_cache=not args.disable_transpose_value_cache, trust_remote_code=True) else: diff --git a/python/llm/src/ipex_llm/transformers/npu_models/common.py b/python/llm/src/ipex_llm/transformers/npu_models/common.py index 0ab4f5ae..45c12fc6 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/common.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/common.py @@ -76,13 +76,19 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"] mlp_module_names = ["down_proj", "up_proj", "gate_proj"] - if isinstance(module, (Qwen2Attention, LlamaAttention)): + if ( + isinstance(module, (Qwen2Attention, LlamaAttention)) + or module.__class__.__name__ in ['MiniCPMAttention', 'Attention'] + ): for name in attn_module_names: setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name, n_splits=n_splits_hidden_size, load=load)) delattr(module, name) - elif isinstance(module, (Qwen2MLP, LlamaMLP)): + elif ( + isinstance(module, (Qwen2MLP, LlamaMLP)) + or module.__class__.__name__ in ['MiniCPMMLP', 'MLP'] + ): for name in mlp_module_names: n_splits_mlp = n_splits_hidden_size if name == 'down_proj': diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py index c10fbfc3..4b8581e4 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py +++ 
b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py @@ -87,9 +87,8 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision, model.llm.config.model_type = "llama" model = model.llm - if model.config.model_type in ["qwen2", "llama"]: + if model.config.model_type in ["qwen2", "llama", "minicpm"]: from ipex_llm.transformers.npu_models.common import split_linears - if quantization_group_size == 0: n_splits_linear = 1 n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1 @@ -110,10 +109,21 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision, if quantization_group_size != 0: split_num = model.config.hidden_size // quantization_group_size - new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num, - bias=model.lm_head.bias, use_split=True) - del model.lm_head - model.lm_head = new_lm_head + if model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40: + # workaround for MiniCPM-2B + new_lm_head_0 = SlicedLMHead(model.lm_head_0.weight, split_num=split_num, + bias=model.lm_head_0.bias, use_split=True) + del model.lm_head_0 + model.lm_head_0 = new_lm_head_0 + new_lm_head_1 = SlicedLMHead(model.lm_head_1.weight, split_num=split_num, + bias=model.lm_head_1.bias, use_split=True) + del model.lm_head_1 + model.lm_head_1 = new_lm_head_1 + else: + new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num, + bias=model.lm_head.bias, use_split=True) + del model.lm_head + model.lm_head = new_lm_head if model.config.model_type == "qwen2": # for Qwen2-7B-Insturct, divide lm_head into 14 parts @@ -372,6 +382,10 @@ def optimize_llm( transpose_value_cache=transpose_value_cache) if hasattr(model, 'lm_head') and isinstance(model.lm_head, SlicedLMHead): model.lm_head.get_fused_lm_head() + # MiniCPM-2b + if hasattr(model, "lm_head_1") and isinstance(model.lm_head_1, SlicedLMHead): + model.lm_head_1.get_fused_lm_head() + model.lm_head_0.get_fused_lm_head() def optimize_funasr( diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py index b237f6cc..8373aab7 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py @@ -110,8 +110,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory): # define input, the order self.parameter matters input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size)) - # llama2 use ov sdp, other models need to test - use_prefill_sdp = self.intermediate_size == 11008 + # llama2/3 use ov sdp, other models need to test + use_prefill_sdp = self.intermediate_size in [11008, 14336] # Self Attention if mode == "decode": @@ -437,7 +437,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module): ) self.layer_norm_0 = layer_norm_0 self.layer_norm_1 = layer_norm_1 - self.use_prefill_sdp = intermediate_size == 11008 + self.use_prefill_sdp = intermediate_size in [11008, 14336] def forward( self, diff --git a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py index 8e12582c..c35f687a 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py @@ -78,13 +78,19 @@ class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory): rms_norm_eps, intermediate_size, scale_depth, - num_hidden_layers + num_hidden_layers, + n_splits_linear: int = 1, + n_splits_down_proj: int = 1, + 
group_size: int = 0 ): super().__init__(max_seq_len=max_seq_len, transpose_value=transpose_value, dtype=dtype, profile=profile, - device=device) + device=device, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size) self.max_seq_len = max_seq_len self.intermediate_size = intermediate_size self.dtype = dtype @@ -235,7 +241,7 @@ class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory): attn_output * layer_scale_depth) residual = hidden_states hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight) - hidden_states = self.mlp(hidden_states) + hidden_states = self.mlp(hidden_states, self.seq_len, self.mode) hidden_states = self.eltwise_add(residual, hidden_states * layer_scale_depth) hidden_states = self.convert_to_fp16(hidden_states) @@ -264,6 +270,9 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): max_seq_len: int = 1024, transpose_value: bool = False, do_print: bool = False, + n_splits_linear: int = 1, + n_splits_down_proj: int = 1, + group_size: int = 0 ): super().__init__() @@ -273,6 +282,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): for w in parameters: if isinstance(w, tuple): # from QuantizedLinear op_parameters.append((w[0].numpy(), w[1].numpy())) + elif w.dtype in [torch.int8, torch.uint8]: # QuantizedLinear weight + op_parameters.append(w.numpy()) + elif isinstance(w, np.ndarray): # scale + op_parameters.append(w) else: op_parameters.append(w.to(torch.float16).numpy()) self.op_parameters = op_parameters @@ -281,6 +294,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): self.transpose_value = transpose_value if isinstance(parameters[0], tuple): np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 + elif parameters[0].dtype == torch.int8: + np_dtype = np.int8 + elif parameters[0].dtype == torch.uint8: + np_dtype = np.uint8 else: # FP16 Linear np_dtype = np.float16 @@ -317,6 +334,9 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): mode="decode", transpose_value=self.transpose_value, dtype=np_dtype, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) self.backend_decoders.append(decoder) @@ -392,6 +412,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module): num_hidden_layers, max_seq_len: int = 128, transpose_value: bool = False, + n_splits_linear: int = 1, + n_splits_down_proj: int = 1, + group_size: int = 0 ): super().__init__() self.op_parameters = parameters @@ -422,6 +445,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module): mode="prefill", transpose_value=self.transpose_value, dtype=np_dtype, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) self.layer_norm_0 = layer_norm_0 self.layer_norm_1 = layer_norm_1 @@ -501,24 +527,53 @@ def run_decode( rms_norm_eps = model.config.rms_norm_eps intermediate_size = model.config.intermediate_size num_hidden_layers = model.config.num_hidden_layers + group_size = getattr(model.config, "group_size", 0) layer_weights = [] input_layer_norm_weights = [] post_attn_layernorm_weights = [] layer_indexs = range(layer_start, layer_end) + n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) + n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) for layer_idx in layer_indexs: curr_layer = model.model.layers[layer_idx] attn_layer = curr_layer.self_attn mlp_layer = curr_layer.mlp - weights = [ - (attn_layer.q_proj.weight, attn_layer.q_proj.scale), - (attn_layer.k_proj.weight, 
attn_layer.k_proj.scale), - (attn_layer.v_proj.weight, attn_layer.v_proj.scale), - (attn_layer.o_proj.weight, attn_layer.o_proj.scale), - (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), - (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), - (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale), - ] + weights = [] + if n_splits_linear == 1: + for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, + attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, + attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, + mlp_layer.up_proj_dq_list): + weights.append((q.weight, q.scale)) + weights.append((k.weight, k.scale)) + weights.append((v.weight, v.scale)) + weights.append((o.weight, o.scale)) + weights.append((g.weight, g.scale)) + weights.append((u.weight, u.scale)) + else: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + + if n_splits_down_proj == 1: + for l in mlp_layer.down_proj_dq_list: + weights.append((l.weight, l.scale)) + else: + l_weights = [] + scales = [] + for l in mlp_layer.down_proj_dq_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) @@ -547,6 +602,9 @@ def run_decode( max_seq_len=max_seq_len, transpose_value=transpose_value_cache, do_print=False, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) dist.barrier() @@ -711,25 +769,55 @@ def run_prefill( intermediate_size = model.config.intermediate_size scale_depth = model.config.scale_depth num_hidden_layers = model.config.num_hidden_layers + group_size = getattr(model.config, "group_size", 0) deocderlayers = [] layer_weights = [] input_layer_norm_weights = [] post_attn_layernorm_weights = [] layer_indexs = range(layer_start, layer_end) + n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) + n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) for layer_idx in layer_indexs: curr_layer = model.model.layers[layer_idx] attn_layer = curr_layer.self_attn mlp_layer = curr_layer.mlp - weights = [ - (attn_layer.q_proj.weight, attn_layer.q_proj.scale), - (attn_layer.k_proj.weight, attn_layer.k_proj.scale), - (attn_layer.v_proj.weight, attn_layer.v_proj.scale), - (attn_layer.o_proj.weight, attn_layer.o_proj.scale), - (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), - (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), - (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale), - ] + weights = [] + + if n_splits_linear == 1: + for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, + attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, + attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, + mlp_layer.up_proj_dq_list): + weights.append((q.weight, q.scale)) + weights.append((k.weight, k.scale)) + weights.append((v.weight, v.scale)) + weights.append((o.weight, o.scale)) + weights.append((g.weight, g.scale)) + weights.append((u.weight, u.scale)) + else: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, 
attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + + if n_splits_down_proj == 1: + for l in mlp_layer.down_proj_dq_list: + weights.append((l.weight, l.scale)) + else: + l_weights = [] + scales = [] + for l in mlp_layer.down_proj_dq_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) @@ -752,6 +840,9 @@ def run_prefill( num_hidden_layers=num_hidden_layers, max_seq_len=max_output_len, transpose_value=transpose_value_cache, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) layer_weights.extend(weights) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 3ac026aa..5aa195a0 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -273,7 +273,6 @@ class LLMBaseNNFactory(NNFactory): self.n_splits_linear, wt_dtype=self.dtype, scale_factor=(self.group_size == 0), is_prefill=(mode == "prefill")) - return attn_output, new_key_states, new_value_states def paraformer_layer_norm(self, hidden_states, layernorm_weight, layernorm_bias): diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py index 4c49781c..539acd69 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py @@ -370,6 +370,9 @@ def convert_llm(model: torch.nn.Module, if hasattr(model, "lm_head") and isinstance(model.lm_head, SlicedLMHead): model.lm_head.get_fused_lm_head() + if hasattr(model, "lm_head_1") and isinstance(model.lm_head_1, SlicedLMHead): + model.lm_head_1.get_fused_lm_head() + model.lm_head_0.get_fused_lm_head() # patch generate function import types diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py index 07017efc..d4dbefdb 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py @@ -81,6 +81,7 @@ class MiniCPMLMHead(LLMBaseNNFactory): transpose_value: bool = False, profile: bool = False, device: str = "NPU", + n_splits: int = 1, ): super().__init__(max_seq_len=max_seq_len, transpose_value=transpose_value, @@ -108,19 +109,37 @@ class MiniCPMLMHead(LLMBaseNNFactory): hidden_states = self.layer_norm(hidden_states, model_norm_weight) if vocab_size == 122753: # for MiniCPM-2B-sft-bf16 - hidden_states_1 = self.linear( - hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype - ) - hidden_states_2 = self.linear( - hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype - ) + if n_splits == 1: + hidden_states_1 = self.linear( + hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype + ) + hidden_states_2 = self.linear( + hidden_states, 73440, self.hidden_size, bias=False, 
wt_dtype=self.dtype + ) + else: + hidden_states_1 = self.dq_split_linear( + hidden_states, 73440, self.hidden_size, + n_splits=n_splits, wt_dtype=dtype, scale_factor=False + ) + hidden_states_2 = self.dq_split_linear( + hidden_states, 73440, self.hidden_size, + n_splits=n_splits, wt_dtype=dtype, scale_factor=False + ) + hidden_states_2 = self.slice(hidden_states_2, begin=[0, 0, 0], end=[1, 1, 49313]) hidden_states = self.concat(hidden_states_1, hidden_states_2, axis=2) else: # for MiniCPM-1B-sft-bf16 - hidden_states = self.linear( - hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype - ) + if n_splits == 1: + hidden_states = self.linear( + hidden_states, self.vocab_size, self.hidden_size, bias=False, + wt_dtype=self.dtype + ) + else: + hidden_states = self.dq_split_linear( + hidden_states, self.vocab_size, self.hidden_size, + n_splits=n_splits, wt_dtype=dtype, scale_factor=False + ) # define outputs hidden_states = self.convert_to_fp32(hidden_states) @@ -145,8 +164,19 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): # for MiniCPM-1B-sft-bf16 weights = [(model.lm_head.weight, model.lm_head.scale)] else: - # TODO - pass + weights = [] + if vocab_size == 122753: + lm_head_list = [model.lm_head_0.lm_heads, model.lm_head_1.lm_heads] + else: + lm_head_list = [model.lm_head.lm_heads] + for lh in lm_head_list: + lm_head_weights = [] + scales = [] + for l in lh: + lm_head_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(lm_head_weights, axis=0), + torch.stack(scales, axis=0))) if isinstance(weights[0], tuple): np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8 else: # FP16 Linear @@ -162,6 +192,7 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): dtype=np_dtype, model_norm_weight=model_norm.weight.to(torch.float16), vocab_size=vocab_size, + n_splits=n_splits_linear ) last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir) @@ -175,8 +206,9 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): else: weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(), ] else: - # TODO - pass + weight_numpy = [v.numpy() for v in weights[0]] + if vocab_size == 122753: + weight_numpy.extend([v.numpy() for v in weights[1]]) for idx, weight in enumerate(weight_numpy): bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin") @@ -214,18 +246,40 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, weights = [] if n_splits_linear == 1: - weights = [ - (attn_layer.q_proj.weight, attn_layer.q_proj.scale), - (attn_layer.k_proj.weight, attn_layer.k_proj.scale), - (attn_layer.v_proj.weight, attn_layer.v_proj.scale), - (attn_layer.o_proj.weight, attn_layer.o_proj.scale), - (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), - (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), - (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale), - ] + for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list, + attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, + attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, + mlp_layer.up_proj_dq_list): + weights.append((q.weight, q.scale)) + weights.append((k.weight, k.scale)) + weights.append((v.weight, v.scale)) + weights.append((o.weight, o.scale)) + weights.append((g.weight, g.scale)) + weights.append((u.weight, u.scale)) else: - # TODO - pass + for layer_list in [attn_layer.q_proj_dq_list, 
attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), + torch.stack(scales, axis=0))) + + if n_splits_down_proj == 1: + for l in mlp_layer.down_proj_dq_list: + weights.append((l.weight, l.scale)) + else: + l_weights = [] + scales = [] + for l in mlp_layer.down_proj_dq_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) @@ -254,6 +308,9 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, mode="decode", transpose_value=transpose_value_cache, dtype=np_dtype, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) rest_blob_path = update_names_of_IR_and_export_blob(single_decoder, f"decoder_layer_{layer_idx}", diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py index 80f15aa4..eb38ad7b 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py @@ -19,6 +19,7 @@ import torch import numpy as np import os from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead +from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): @@ -27,18 +28,16 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): rms_norm_eps = model.config.rms_norm_eps vocab_size = model.config.vocab_size model_norm = model.model.norm - if model.config.intermediate_size == 18944: - lm_heads = model.lm_head.lm_heads # Qwen2-7B is always SlicedLMHead - else: - lm_heads = [model.lm_head] - if n_splits_linear == 1: - weights = [(lm_heads[0].weight, lm_heads[0].scale)] + lm_head = model.lm_head + if not isinstance(lm_head, SlicedLMHead): + weights = [(lm_head.weight, lm_head.scale)] else: + lm_heads = lm_head.lm_heads lm_head_weights = [] scales = [] - for i in range(n_splits_linear): - lm_head_weights.append(lm_heads[i].weight) - scales.append(lm_heads[i].scale) + for l in lm_heads: + lm_head_weights.append(l.weight) + scales.append(l.scale) weights = [(torch.stack(lm_head_weights, axis=0), torch.stack(scales, axis=0))] if isinstance(weights[0], tuple): @@ -61,9 +60,9 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir): last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir) # save weights bins files - if n_splits_linear == 1: + if not isinstance(lm_head, SlicedLMHead): weight_numpy = [ - lm_heads[0].weight.data.numpy(), lm_heads[0].scale.data.numpy(), + lm_head.weight.data.numpy(), lm_head.scale.data.numpy(), ] else: weight_numpy = [v.numpy() for v in weights[0]]
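
Note (illustrative, not part of the patch): run_decode, run_prefill, and convert_minicpm_layer all repeat the same weight-gathering pattern for the decomposed-quantized (dq) projection lists: with a single split, each projection's (weight, scale) pair is appended directly, while with groupwise quantization the per-split tensors are stacked along a new leading dimension so each projection contributes one fused parameter to the NPU decoder. A minimal sketch of that pattern, factored into a hypothetical helper (pack_dq_weights is not in the codebase; it only assumes each *_dq_list entry exposes .weight and .scale, as in the diff):

import torch

def pack_dq_weights(proj_lists, n_splits):
    # proj_lists: e.g. [q_proj_dq_list, k_proj_dq_list, v_proj_dq_list,
    #                   o_proj_dq_list, gate_proj_dq_list, up_proj_dq_list]
    weights = []
    for layer_list in proj_lists:
        if n_splits == 1:
            # single split: the list holds one quantized linear, keep its pair as-is
            l = layer_list[0]
            weights.append((l.weight, l.scale))
        else:
            # groupwise: stack per-split weights/scales along a new leading dim
            weights.append((torch.stack([l.weight for l in layer_list], dim=0),
                            torch.stack([l.scale for l in layer_list], dim=0)))
    return weights

The down_proj_dq_list is handled the same way but with n_splits_down_proj, which is why the patch keeps it in a separate branch after the attention/gate/up projections.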
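
Also for reference (values taken from the diff, not new behavior): in the MiniCPM-2B LM-head graph the two heads are each 73440 wide, and the second output is sliced to 49313 columns before concatenation, which reconstructs the full 122753-entry vocabulary. A quick check of that arithmetic:

head_width = 73440                   # width of hidden_states_1 / hidden_states_2 in the graph
vocab_size = 122753                  # MiniCPM-2B-sft-bf16 vocabulary size (from the diff)
slice_end = vocab_size - head_width  # 49313, the slice bound used above
assert head_width + slice_end == vocab_size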