[NPU] support asym_int4 for minicpm (#12567)
This commit is contained in:
		
							parent
							
								
									6e801bc4e1
								
							
						
					
					
						commit
						1a2ab12876
					
				
					 2 changed files with 146 additions and 40 deletions
				
			
		| 
						 | 
				
			
			@ -81,7 +81,8 @@ class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory):
 | 
			
		|||
        num_hidden_layers,
 | 
			
		||||
        n_splits_linear: int = 1,
 | 
			
		||||
        n_splits_down_proj: int = 1,
 | 
			
		||||
        group_size: int = 0
 | 
			
		||||
        group_size: int = 0,
 | 
			
		||||
        asym: bool = False,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__(max_seq_len=max_seq_len,
 | 
			
		||||
                         transpose_value=transpose_value,
 | 
			
		||||
| 
						 | 
				
			
			@ -90,7 +91,8 @@ class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory):
 | 
			
		|||
                         device=device,
 | 
			
		||||
                         n_splits_linear=n_splits_linear,
 | 
			
		||||
                         n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
                         group_size=group_size)
 | 
			
		||||
                         group_size=group_size,
 | 
			
		||||
                         asym=asym)
 | 
			
		||||
        self.max_seq_len = max_seq_len
 | 
			
		||||
        self.intermediate_size = intermediate_size
 | 
			
		||||
        self.dtype = dtype
 | 
			
		||||
| 
						 | 
				
			
			@ -272,7 +274,8 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
        do_print: bool = False,
 | 
			
		||||
        n_splits_linear: int = 1,
 | 
			
		||||
        n_splits_down_proj: int = 1,
 | 
			
		||||
        group_size: int = 0
 | 
			
		||||
        group_size: int = 0,
 | 
			
		||||
        asym: bool = False,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -280,8 +283,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
 | 
			
		||||
        op_parameters = []
 | 
			
		||||
        for w in parameters:
 | 
			
		||||
            if isinstance(w, tuple):  # from QuantizedLinear
 | 
			
		||||
            if isinstance(w, tuple) and not asym:  # from QuantizedLinear
 | 
			
		||||
                op_parameters.append((w[0].numpy(), w[1].numpy()))
 | 
			
		||||
            elif isinstance(w, tuple) and asym:  # from QuantizedLinear
 | 
			
		||||
                op_parameters.append((w[0].numpy(), w[1].numpy(),  w[2].numpy()))
 | 
			
		||||
            elif w.dtype in [torch.int8, torch.uint8]:    # QuantizedLinear weight
 | 
			
		||||
                op_parameters.append(w.numpy())
 | 
			
		||||
            elif isinstance(w, np.ndarray):     # scale
 | 
			
		||||
| 
						 | 
				
			
			@ -336,7 +341,8 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
                dtype=np_dtype,
 | 
			
		||||
                n_splits_linear=n_splits_linear,
 | 
			
		||||
                n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
                group_size=group_size
 | 
			
		||||
                group_size=group_size,
 | 
			
		||||
                asym=asym,
 | 
			
		||||
            )
 | 
			
		||||
            self.backend_decoders.append(decoder)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -414,7 +420,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
        transpose_value: bool = False,
 | 
			
		||||
        n_splits_linear: int = 1,
 | 
			
		||||
        n_splits_down_proj: int = 1,
 | 
			
		||||
        group_size: int = 0
 | 
			
		||||
        group_size: int = 0,
 | 
			
		||||
        asym: bool = False,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.op_parameters = parameters
 | 
			
		||||
| 
						 | 
				
			
			@ -447,7 +454,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
            dtype=np_dtype,
 | 
			
		||||
            n_splits_linear=n_splits_linear,
 | 
			
		||||
            n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
            group_size=group_size
 | 
			
		||||
            group_size=group_size,
 | 
			
		||||
            asym=asym,
 | 
			
		||||
        )
 | 
			
		||||
        self.layer_norm_0 = layer_norm_0
 | 
			
		||||
        self.layer_norm_1 = layer_norm_1
 | 
			
		||||
| 
						 | 
				
			
			@ -534,6 +542,7 @@ def run_decode(
 | 
			
		|||
    layer_indexs = range(layer_start, layer_end)
 | 
			
		||||
    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
 | 
			
		||||
    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
 | 
			
		||||
    asym = getattr(model.config, "asym", False)
 | 
			
		||||
    for layer_idx in layer_indexs:
 | 
			
		||||
        curr_layer = model.model.layers[layer_idx]
 | 
			
		||||
        attn_layer = curr_layer.self_attn
 | 
			
		||||
| 
						 | 
				
			
			@ -546,10 +555,17 @@ def run_decode(
 | 
			
		|||
                           mlp_layer.down_proj_dq_list]:
 | 
			
		||||
            l_weights = []
 | 
			
		||||
            scales = []
 | 
			
		||||
            zeros = []
 | 
			
		||||
            for l in layer_list:
 | 
			
		||||
                l_weights.append(l.weight)
 | 
			
		||||
                scales.append(l.scale)
 | 
			
		||||
            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
                if l.zero is not None:
 | 
			
		||||
                    zeros.append(l.zero)
 | 
			
		||||
            if len(zeros):
 | 
			
		||||
                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
 | 
			
		||||
                                torch.stack(zeros, axis=0)))
 | 
			
		||||
            else:
 | 
			
		||||
                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
 | 
			
		||||
        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
 | 
			
		||||
| 
						 | 
				
			
			@ -580,7 +596,8 @@ def run_decode(
 | 
			
		|||
        do_print=False,
 | 
			
		||||
        n_splits_linear=n_splits_linear,
 | 
			
		||||
        n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
        group_size=group_size
 | 
			
		||||
        group_size=group_size,
 | 
			
		||||
        asym=asym,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    dist.barrier()
 | 
			
		||||
| 
						 | 
				
			
			@ -753,6 +770,7 @@ def run_prefill(
 | 
			
		|||
    layer_indexs = range(layer_start, layer_end)
 | 
			
		||||
    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
 | 
			
		||||
    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
 | 
			
		||||
    asym = getattr(model.config, "asym", False)
 | 
			
		||||
    for layer_idx in layer_indexs:
 | 
			
		||||
        curr_layer = model.model.layers[layer_idx]
 | 
			
		||||
        attn_layer = curr_layer.self_attn
 | 
			
		||||
| 
						 | 
				
			
			@ -765,10 +783,17 @@ def run_prefill(
 | 
			
		|||
                           mlp_layer.down_proj_dq_list]:
 | 
			
		||||
            l_weights = []
 | 
			
		||||
            scales = []
 | 
			
		||||
            zeros = []
 | 
			
		||||
            for l in layer_list:
 | 
			
		||||
                l_weights.append(l.weight)
 | 
			
		||||
                scales.append(l.scale)
 | 
			
		||||
            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
                if l.zero is not None:
 | 
			
		||||
                    zeros.append(l.zero)
 | 
			
		||||
            if len(zeros):
 | 
			
		||||
                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
 | 
			
		||||
                                torch.stack(zeros, axis=0)))
 | 
			
		||||
            else:
 | 
			
		||||
                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
 | 
			
		||||
        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
 | 
			
		||||
| 
						 | 
				
			
			@ -793,7 +818,8 @@ def run_prefill(
 | 
			
		|||
            transpose_value=transpose_value_cache,
 | 
			
		||||
            n_splits_linear=n_splits_linear,
 | 
			
		||||
            n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
            group_size=group_size
 | 
			
		||||
            group_size=group_size,
 | 
			
		||||
            asym=asym
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        layer_weights.extend(weights)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -105,6 +105,7 @@ class MiniCPMLMHead(LLMBaseNNFactory):
 | 
			
		|||
        profile: bool = False,
 | 
			
		||||
        device: str = "NPU",
 | 
			
		||||
        n_splits: int = 1,
 | 
			
		||||
        asym: bool = False
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__(max_seq_len=max_seq_len,
 | 
			
		||||
                         transpose_value=transpose_value,
 | 
			
		||||
| 
						 | 
				
			
			@ -134,11 +135,13 @@ class MiniCPMLMHead(LLMBaseNNFactory):
 | 
			
		|||
            # for MiniCPM-2B-sft-bf16
 | 
			
		||||
            hidden_states_1 = self.linear(
 | 
			
		||||
                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype,
 | 
			
		||||
                n_splits=n_splits, scale_factor=(n_splits == 1)
 | 
			
		||||
                n_splits=n_splits, scale_factor=(n_splits == 1),
 | 
			
		||||
                asym=asym
 | 
			
		||||
            )
 | 
			
		||||
            hidden_states_2 = self.linear(
 | 
			
		||||
                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype,
 | 
			
		||||
                n_splits=n_splits, scale_factor=(n_splits == 1)
 | 
			
		||||
                n_splits=n_splits, scale_factor=(n_splits == 1),
 | 
			
		||||
                asym=asym
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            hidden_states_2 = self.slice(hidden_states_2, begin=[0, 0, 0], end=[1, 1, 49313])
 | 
			
		||||
| 
						 | 
				
			
			@ -147,7 +150,8 @@ class MiniCPMLMHead(LLMBaseNNFactory):
 | 
			
		|||
            # for MiniCPM-1B-sft-bf16
 | 
			
		||||
            hidden_states = self.linear(
 | 
			
		||||
                hidden_states, self.vocab_size, self.hidden_size, bias=False,
 | 
			
		||||
                wt_dtype=self.dtype, n_splits=n_splits, scale_factor=(n_splits == 1)
 | 
			
		||||
                wt_dtype=self.dtype, n_splits=n_splits, scale_factor=(n_splits == 1),
 | 
			
		||||
                asym=asym
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # define outputs
 | 
			
		||||
| 
						 | 
				
			
			@ -165,28 +169,48 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
 | 
			
		|||
    rms_norm_eps = model.config.rms_norm_eps
 | 
			
		||||
    vocab_size = model.config.vocab_size
 | 
			
		||||
    model_norm = model.model.norm
 | 
			
		||||
    asym = getattr(model.config, "asym", False)
 | 
			
		||||
    if n_splits_linear == 1:
 | 
			
		||||
        if vocab_size == 122753:
 | 
			
		||||
            # for MiniCPM-2B-sft-bf16
 | 
			
		||||
            weights = [(model.lm_head_0.weight, model.lm_head_0.scale),
 | 
			
		||||
                       (model.lm_head_1.weight, model.lm_head_1.scale)]
 | 
			
		||||
            asym = model.lm_head_0.qtype == "asym_int4_rtn"
 | 
			
		||||
            if asym:
 | 
			
		||||
                weights = [(model.lm_head_0.weight, model.lm_head_0.scale, model.lm_head_0.zero),
 | 
			
		||||
                           (model.lm_head_1.weight, model.lm_head_1.scale, model.lm_head_1.zero)]
 | 
			
		||||
            else:
 | 
			
		||||
                weights = [(model.lm_head_0.weight, model.lm_head_0.scale),
 | 
			
		||||
                           (model.lm_head_1.weight, model.lm_head_1.scale)]
 | 
			
		||||
        else:
 | 
			
		||||
            # for MiniCPM-1B-sft-bf16
 | 
			
		||||
            weights = [(model.lm_head.weight, model.lm_head.scale)]
 | 
			
		||||
            asym = model.lm_head.qtype == "asym_int4_rtn"
 | 
			
		||||
            if asym:
 | 
			
		||||
                weights = [(model.lm_head.weight, model.lm_head.scale, model.lm_head.zero)]
 | 
			
		||||
            else:
 | 
			
		||||
                weights = [(model.lm_head.weight, model.lm_head.scale)]
 | 
			
		||||
    else:
 | 
			
		||||
        weights = []
 | 
			
		||||
        if vocab_size == 122753:
 | 
			
		||||
            asym = model.lm_head_0.lm_heads[0].qtype == "asym_int4_rtn"
 | 
			
		||||
            lm_head_list = [model.lm_head_0.lm_heads, model.lm_head_1.lm_heads]
 | 
			
		||||
        else:
 | 
			
		||||
            asym = model.lm_head.lm_heads[0].qtype == "asym_int4_rtn"
 | 
			
		||||
            lm_head_list = [model.lm_head.lm_heads]
 | 
			
		||||
        for lh in lm_head_list:
 | 
			
		||||
            lm_head_weights = []
 | 
			
		||||
            scales = []
 | 
			
		||||
            zeros = []
 | 
			
		||||
            for l in lh:
 | 
			
		||||
                lm_head_weights.append(l.weight)
 | 
			
		||||
                scales.append(l.scale)
 | 
			
		||||
            weights.append((torch.stack(lm_head_weights, axis=0),
 | 
			
		||||
                            torch.stack(scales, axis=0)))
 | 
			
		||||
                if l.zero is not None:
 | 
			
		||||
                    zeros.append(l.zero)
 | 
			
		||||
            if len(zeros):
 | 
			
		||||
                weights.append((torch.stack(lm_head_weights, axis=0),
 | 
			
		||||
                                torch.stack(scales, axis=0),
 | 
			
		||||
                                torch.stack(zeros, axis=0)))
 | 
			
		||||
            else:
 | 
			
		||||
                weights.append((torch.stack(lm_head_weights, axis=0),
 | 
			
		||||
                                torch.stack(scales, axis=0)))
 | 
			
		||||
    if isinstance(weights[0], tuple):
 | 
			
		||||
        np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
 | 
			
		||||
    else:  # FP16 Linear
 | 
			
		||||
| 
						 | 
				
			
			@ -202,7 +226,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
 | 
			
		|||
        dtype=np_dtype,
 | 
			
		||||
        model_norm_weight=model_norm.weight.to(torch.float16),
 | 
			
		||||
        vocab_size=vocab_size,
 | 
			
		||||
        n_splits=n_splits_linear
 | 
			
		||||
        n_splits=n_splits_linear,
 | 
			
		||||
        asym=asym
 | 
			
		||||
    )
 | 
			
		||||
    last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir,
 | 
			
		||||
                                                        True, True)
 | 
			
		||||
| 
						 | 
				
			
			@ -210,12 +235,24 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
 | 
			
		|||
    # save weights bins files
 | 
			
		||||
    if n_splits_linear == 1:
 | 
			
		||||
        if vocab_size == 122753:
 | 
			
		||||
            weight_numpy = [model.lm_head_0.weight.data.numpy(),
 | 
			
		||||
                            model.lm_head_0.scale.data.numpy(),
 | 
			
		||||
                            model.lm_head_1.weight.data.numpy(),
 | 
			
		||||
                            model.lm_head_1.scale.data.numpy(), ]
 | 
			
		||||
            if not asym:
 | 
			
		||||
                weight_numpy = [model.lm_head_0.weight.data.numpy(),
 | 
			
		||||
                                model.lm_head_0.scale.data.numpy(),
 | 
			
		||||
                                model.lm_head_1.weight.data.numpy(),
 | 
			
		||||
                                model.lm_head_1.scale.data.numpy(), ]
 | 
			
		||||
            else:
 | 
			
		||||
                weight_numpy = [model.lm_head_0.weight.data.numpy(),
 | 
			
		||||
                                model.lm_head_0.scale.data.numpy(),
 | 
			
		||||
                                model.lm_head_0.zero.data.numpy(),
 | 
			
		||||
                                model.lm_head_1.weight.data.numpy(),
 | 
			
		||||
                                model.lm_head_1.scale.data.numpy(),
 | 
			
		||||
                                model.lm_head_1.zero.data.numpy(), ]
 | 
			
		||||
        else:
 | 
			
		||||
            weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(), ]
 | 
			
		||||
            if not asym:
 | 
			
		||||
                weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy()]
 | 
			
		||||
            else:
 | 
			
		||||
                weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(),
 | 
			
		||||
                                model.lm_head.zero.data.numpy()]
 | 
			
		||||
    else:
 | 
			
		||||
        weight_numpy = [v.numpy() for v in weights[0]]
 | 
			
		||||
        if vocab_size == 122753:
 | 
			
		||||
| 
						 | 
				
			
			@ -266,6 +303,7 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 | 
			
		|||
    rms_norm_eps = model.config.rms_norm_eps
 | 
			
		||||
    num_hidden_layers = model.config.num_hidden_layers
 | 
			
		||||
    scale_depth = model.model.config.scale_depth
 | 
			
		||||
    asym = getattr(model.config, "asym", False)
 | 
			
		||||
 | 
			
		||||
    from ipex_llm.transformers.npu_models.minicpm_mp import LowBitMinicpmMultiDecoderlayer
 | 
			
		||||
    curr_layer = model.model.layers[layer_idx]
 | 
			
		||||
| 
						 | 
				
			
			@ -279,10 +317,17 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 | 
			
		|||
                       mlp_layer.down_proj_dq_list]:
 | 
			
		||||
        l_weights = []
 | 
			
		||||
        scales = []
 | 
			
		||||
        zeros = []
 | 
			
		||||
        for l in layer_list:
 | 
			
		||||
            l_weights.append(l.weight)
 | 
			
		||||
            scales.append(l.scale)
 | 
			
		||||
        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
            if l.zero is not None:
 | 
			
		||||
                zeros.append(l.zero)
 | 
			
		||||
        if len(zeros):
 | 
			
		||||
            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
 | 
			
		||||
                            torch.stack(zeros, axis=0)))
 | 
			
		||||
        else:
 | 
			
		||||
            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
    cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
 | 
			
		||||
    cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
 | 
			
		||||
| 
						 | 
				
			
			@ -321,7 +366,8 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 | 
			
		|||
        dtype=np_dtype,
 | 
			
		||||
        n_splits_linear=n_splits_linear,
 | 
			
		||||
        n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
        group_size=group_size
 | 
			
		||||
        group_size=group_size,
 | 
			
		||||
        asym=asym
 | 
			
		||||
    )
 | 
			
		||||
    rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
 | 
			
		||||
                                                        decoder_name,
 | 
			
		||||
| 
						 | 
				
			
			@ -337,11 +383,23 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 | 
			
		|||
            layer_norm_0.data.numpy().tofile(input_lm_bin_file)
 | 
			
		||||
            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
 | 
			
		||||
            st_idx = 7
 | 
			
		||||
        for idx, (weight, scale) in enumerate(weights):
 | 
			
		||||
            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
 | 
			
		||||
            weight.numpy().tofile(bin_file)
 | 
			
		||||
            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
 | 
			
		||||
            scale.numpy().tofile(bin_file)
 | 
			
		||||
        if not asym:
 | 
			
		||||
            for idx, (weight, scale) in enumerate(weights):
 | 
			
		||||
                bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
 | 
			
		||||
                weight.numpy().tofile(bin_file)
 | 
			
		||||
                bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
 | 
			
		||||
                scale.numpy().tofile(bin_file)
 | 
			
		||||
        else:
 | 
			
		||||
            for idx, (weight, scale, zero) in enumerate(weights):
 | 
			
		||||
                bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*3}.bin")
 | 
			
		||||
                weight.numpy().tofile(bin_file)
 | 
			
		||||
                bin_file = os.path.join(weight_dir,
 | 
			
		||||
                                        f"model_{layer_idx}_input_{st_idx+idx*3+1}.bin")
 | 
			
		||||
                scale.numpy().tofile(bin_file)
 | 
			
		||||
                bin_file = os.path.join(weight_dir,
 | 
			
		||||
                                        f"model_{layer_idx}_input_{st_idx+idx*3+2}.bin")
 | 
			
		||||
                zero.numpy().tofile(bin_file)
 | 
			
		||||
 | 
			
		||||
        del single_decoder
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -357,6 +415,7 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
 | 
			
		|||
    scale_depth = model.model.config.scale_depth
 | 
			
		||||
    layer_num = len(model.model.layers)
 | 
			
		||||
    fused_layer_num = layer_num // fused_layers
 | 
			
		||||
    asym = getattr(model.config, "asym", False)
 | 
			
		||||
 | 
			
		||||
    from ipex_llm.transformers.npu_models.minicpm_mp import LowBitMinicpmMultiDecoderlayer
 | 
			
		||||
    for i in range(fused_layers):
 | 
			
		||||
| 
						 | 
				
			
			@ -380,10 +439,17 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
 | 
			
		|||
                               mlp_layer.down_proj_dq_list]:
 | 
			
		||||
                l_weights = []
 | 
			
		||||
                scales = []
 | 
			
		||||
                zeros = []
 | 
			
		||||
                for l in layer_list:
 | 
			
		||||
                    l_weights.append(l.weight)
 | 
			
		||||
                    scales.append(l.scale)
 | 
			
		||||
                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
                    if l.zero is not None:
 | 
			
		||||
                        zeros.append(l.zero)
 | 
			
		||||
                if len(zeros):
 | 
			
		||||
                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
 | 
			
		||||
                                    torch.stack(zeros, axis=0)))
 | 
			
		||||
                else:
 | 
			
		||||
                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
            cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
 | 
			
		||||
            cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
 | 
			
		||||
| 
						 | 
				
			
			@ -401,12 +467,25 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
 | 
			
		|||
            layer_norm_1.data.numpy().tofile(post_lm_bin_file)
 | 
			
		||||
            st_idx = 5
 | 
			
		||||
            # 6, 7 are past k/v
 | 
			
		||||
            for idx, (weight, scale) in enumerate(weights):
 | 
			
		||||
                bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
 | 
			
		||||
                weight.numpy().tofile(bin_file)
 | 
			
		||||
                bin_file = os.path.join(weight_dir,
 | 
			
		||||
                                        f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
 | 
			
		||||
                scale.numpy().tofile(bin_file)
 | 
			
		||||
            if not asym:
 | 
			
		||||
                for idx, (weight, scale) in enumerate(weights):
 | 
			
		||||
                    bin_file = os.path.join(weight_dir,
 | 
			
		||||
                                            f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
 | 
			
		||||
                    weight.numpy().tofile(bin_file)
 | 
			
		||||
                    bin_file = os.path.join(weight_dir,
 | 
			
		||||
                                            f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
 | 
			
		||||
                    scale.numpy().tofile(bin_file)
 | 
			
		||||
            else:
 | 
			
		||||
                for idx, (weight, scale, zero) in enumerate(weights):
 | 
			
		||||
                    bin_file = os.path.join(weight_dir,
 | 
			
		||||
                                            f"model_{layer_idx}_input_{st_idx+idx*3}.bin")
 | 
			
		||||
                    weight.numpy().tofile(bin_file)
 | 
			
		||||
                    bin_file = os.path.join(weight_dir,
 | 
			
		||||
                                            f"model_{layer_idx}_input_{st_idx+idx*3+1}.bin")
 | 
			
		||||
                    scale.numpy().tofile(bin_file)
 | 
			
		||||
                    bin_file = os.path.join(weight_dir,
 | 
			
		||||
                                            f"model_{layer_idx}_input_{st_idx+idx*3+2}.bin")
 | 
			
		||||
                    zero.numpy().tofile(bin_file)
 | 
			
		||||
 | 
			
		||||
        if isinstance(weights[0], tuple):
 | 
			
		||||
            np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
 | 
			
		||||
| 
						 | 
				
			
			@ -432,7 +511,8 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
 | 
			
		|||
            dtype=np_dtype,
 | 
			
		||||
            n_splits_linear=n_splits_linear,
 | 
			
		||||
            n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
            group_size=group_size
 | 
			
		||||
            group_size=group_size,
 | 
			
		||||
            asym=asym
 | 
			
		||||
        )
 | 
			
		||||
        update_names_of_IR_and_export_blob(fused_decoder,
 | 
			
		||||
                                           f"decoder_layer_{i}",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue