[NPU] Llama3, Qwen2 1.5b, MiniCPM 1/2B groupwise support (#12327)
* support minicpm 1b & qwen 1.5b gw
* support minicpm 1b
* support minicpm 2b
* fix style & error
* fix style & update
* remove print
parent 82a61b5cf3
commit d872639395

9 changed files with 239 additions and 68 deletions
@@ -47,6 +47,7 @@ if __name__ == "__main__":
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-context-len", type=int, default=1024)
     parser.add_argument("--max-prompt-len", type=int, default=512)
+    parser.add_argument("--quantization_group_size", type=int, default=0)
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
     parser.add_argument("--disable-streaming", action="store_true", default=False)
 
@@ -61,6 +62,7 @@ if __name__ == "__main__":
                                                      max_prompt_len=args.max_prompt_len,
                                                      torch_dtype=torch.float16,
                                                      attn_implementation="eager",
+                                                     quantization_group_size=args.quantization_group_size,
                                                      transpose_value_cache=not args.disable_transpose_value_cache,
                                                      trust_remote_code=True)
     else:
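For context, here is a minimal sketch of how the new flag is meant to reach the loader, pieced together only from the two hunks above; the `ipex_llm.transformers.npu_model` import path, the model id, and the `load_in_low_bit` value are assumptions rather than part of this diff:

```python
import argparse
import torch
# Assumed import path for the NPU-optimized loader; not shown in this diff.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

parser = argparse.ArgumentParser()
parser.add_argument("--quantization_group_size", type=int, default=0)  # new in this commit
args = parser.parse_args()

# 0 keeps the previous channel-wise quantization; a non-zero value such as 64
# switches the NPU path to groupwise quantization with that group size.
model = AutoModelForCausalLM.from_pretrained(
    "openbmb/MiniCPM-1B-sft-bf16",               # example model id (assumption)
    load_in_low_bit="sym_int4",                  # assumed low-bit format
    optimize_model=True,
    max_context_len=1024,
    max_prompt_len=512,
    quantization_group_size=args.quantization_group_size,
    torch_dtype=torch.float16,
    attn_implementation="eager",
    transpose_value_cache=True,
    trust_remote_code=True,
)
```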
@@ -76,13 +76,19 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down
     from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
     attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
     mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
-    if isinstance(module, (Qwen2Attention, LlamaAttention)):
+    if (
+        isinstance(module, (Qwen2Attention, LlamaAttention))
+        or module.__class__.__name__ in ['MiniCPMAttention', 'Attention']
+    ):
         for name in attn_module_names:
             setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
                                                             n_splits=n_splits_hidden_size,
                                                             load=load))
             delattr(module, name)
-    elif isinstance(module, (Qwen2MLP, LlamaMLP)):
+    elif (
+        isinstance(module, (Qwen2MLP, LlamaMLP))
+        or module.__class__.__name__ in ['MiniCPMMLP', 'MLP']
+    ):
         for name in mlp_module_names:
             n_splits_mlp = n_splits_hidden_size
             if name == 'down_proj':
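The MiniCPM modules are matched by class name ('MiniCPMAttention'/'MiniCPMMLP', or the generic 'Attention'/'MLP' used by some remote-code checkpoints) because these models are typically loaded with trust_remote_code, so there is no transformers class to isinstance against. As background on the split itself, here is a self-contained sketch of the idea behind split_linear; it is not the library implementation (which also handles low-bit weights and deferred loading), just a demonstration that splitting a projection along its input dimension is lossless while letting each chunk carry its own quantization scales:

```python
import torch

def split_linear_sketch(linear: torch.nn.Linear, n_splits: int):
    """Split a Linear along in_features into n_splits smaller Linears (illustration only)."""
    step = linear.in_features // n_splits
    assert step * n_splits == linear.in_features
    parts = []
    for i in range(n_splits):
        part = torch.nn.Linear(step, linear.out_features, bias=False)
        part.weight.data = linear.weight.data[:, i * step:(i + 1) * step].clone()
        parts.append(part)
    return parts

full = torch.nn.Linear(1536, 1536, bias=False)   # sizes are illustrative
parts = split_linear_sketch(full, n_splits=4)
x = torch.randn(1, 1536)
step = 1536 // 4
# Summing the partial projections over input chunks reproduces the original
# output, so the model's math is unchanged; each chunk can then be quantized
# with its own scales (the point of groupwise quantization).
recomposed = sum(p(x[:, i * step:(i + 1) * step]) for i, p in enumerate(parts))
torch.testing.assert_close(recomposed, full(x), rtol=1e-4, atol=1e-4)
```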
@@ -87,9 +87,8 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
             model.llm.config.model_type = "llama"
         model = model.llm
 
-    if model.config.model_type in ["qwen2", "llama"]:
+    if model.config.model_type in ["qwen2", "llama", "minicpm"]:
         from ipex_llm.transformers.npu_models.common import split_linears
 
         if quantization_group_size == 0:
             n_splits_linear = 1
             n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
@@ -110,10 +109,21 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
 
         if quantization_group_size != 0:
             split_num = model.config.hidden_size // quantization_group_size
-            new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
-                                       bias=model.lm_head.bias, use_split=True)
-            del model.lm_head
-            model.lm_head = new_lm_head
+            if model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40:
+                # workaround for MiniCPM-2B
+                new_lm_head_0 = SlicedLMHead(model.lm_head_0.weight, split_num=split_num,
+                                             bias=model.lm_head_0.bias, use_split=True)
+                del model.lm_head_0
+                model.lm_head_0 = new_lm_head_0
+                new_lm_head_1 = SlicedLMHead(model.lm_head_1.weight, split_num=split_num,
+                                             bias=model.lm_head_1.bias, use_split=True)
+                del model.lm_head_1
+                model.lm_head_1 = new_lm_head_1
+            else:
+                new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
+                                           bias=model.lm_head.bias, use_split=True)
+                del model.lm_head
+                model.lm_head = new_lm_head
 
     if model.config.model_type == "qwen2":
         # for Qwen2-7B-Insturct, divide lm_head into 14 parts
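MiniCPM-2B exposes two output heads (lm_head_0/lm_head_1, see the MiniCPM LM-head changes further down), so both get wrapped here. For a feel of what split_num = hidden_size // quantization_group_size works out to in practice, a tiny illustration follows; the hidden sizes are quoted from memory of the public model configs and the group size is only an example, so treat the numbers as assumptions:

```python
# split_num is the number of quantization groups per row of the LM head.
hidden_sizes = {
    "Qwen2-1.5B-Instruct": 1536,
    "MiniCPM-1B-sft-bf16": 1536,
    "MiniCPM-2B-sft-bf16": 2304,
}
group_size = 64  # example value for --quantization_group_size
for name, hidden_size in hidden_sizes.items():
    print(name, hidden_size // group_size)   # 24, 24, 36 slices respectively
```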
@@ -372,6 +382,10 @@ def optimize_llm(
                          transpose_value_cache=transpose_value_cache)
     if hasattr(model, 'lm_head') and isinstance(model.lm_head, SlicedLMHead):
         model.lm_head.get_fused_lm_head()
+    # MiniCPM-2b
+    if hasattr(model, "lm_head_1") and isinstance(model.lm_head_1, SlicedLMHead):
+        model.lm_head_1.get_fused_lm_head()
+        model.lm_head_0.get_fused_lm_head()
 
 
 def optimize_funasr(
@@ -110,8 +110,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         # define input, the order self.parameter matters
         input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
 
-        # llama2 use ov sdp, other models need to test
-        use_prefill_sdp = self.intermediate_size == 11008
+        # llama2/3 use ov sdp, other models need to test
+        use_prefill_sdp = self.intermediate_size in [11008, 14336]
 
         # Self Attention
         if mode == "decode":
@@ -437,7 +437,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
-        self.use_prefill_sdp = intermediate_size == 11008
+        self.use_prefill_sdp = intermediate_size in [11008, 14336]
 
     def forward(
         self,
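The two occurrences of this check gate the OpenVINO SDP prefill path on intermediate_size, which acts as a cheap proxy for the model family; 14336 extends it to Llama-3-8B alongside the original Llama-2-7B case. A small reference sketch (sizes taken from the public Llama configs):

```python
# intermediate_size values that enable use_prefill_sdp after this commit
OV_SDP_PREFILL_SIZES = {
    11008: "Llama-2-7B",   # previously the only case
    14336: "Llama-3-8B",   # added here
}

def use_prefill_sdp(intermediate_size: int) -> bool:
    return intermediate_size in OV_SDP_PREFILL_SIZES
```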
@@ -78,13 +78,19 @@ class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory):
         rms_norm_eps,
         intermediate_size,
         scale_depth,
-        num_hidden_layers
+        num_hidden_layers,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
                          dtype=dtype,
                          profile=profile,
-                         device=device)
+                         device=device,
+                         n_splits_linear=n_splits_linear,
+                         n_splits_down_proj=n_splits_down_proj,
+                         group_size=group_size)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -235,7 +241,7 @@ class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory):
                                          attn_output * layer_scale_depth)
         residual = hidden_states
         hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, self.seq_len, self.mode)
         hidden_states = self.eltwise_add(residual,
                                          hidden_states * layer_scale_depth)
         hidden_states = self.convert_to_fp16(hidden_states)
@@ -264,6 +270,9 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         max_seq_len: int = 1024,
         transpose_value: bool = False,
         do_print: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__()
 
			
		||||
| 
						 | 
				
			
			@ -273,6 +282,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
        for w in parameters:
 | 
			
		||||
            if isinstance(w, tuple):  # from QuantizedLinear
 | 
			
		||||
                op_parameters.append((w[0].numpy(), w[1].numpy()))
 | 
			
		||||
            elif w.dtype in [torch.int8, torch.uint8]:    # QuantizedLinear weight
 | 
			
		||||
                op_parameters.append(w.numpy())
 | 
			
		||||
            elif isinstance(w, np.ndarray):     # scale
 | 
			
		||||
                op_parameters.append(w)
 | 
			
		||||
            else:
 | 
			
		||||
                op_parameters.append(w.to(torch.float16).numpy())
 | 
			
		||||
        self.op_parameters = op_parameters
@@ -281,6 +294,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         self.transpose_value = transpose_value
         if isinstance(parameters[0], tuple):
             np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
+        elif parameters[0].dtype == torch.int8:
+            np_dtype = np.int8
+        elif parameters[0].dtype == torch.uint8:
+            np_dtype = np.uint8
         else:  # FP16 Linear
             np_dtype = np.float16
 
			
		||||
| 
						 | 
				
			
			@ -317,6 +334,9 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
                mode="decode",
 | 
			
		||||
                transpose_value=self.transpose_value,
 | 
			
		||||
                dtype=np_dtype,
 | 
			
		||||
                n_splits_linear=n_splits_linear,
 | 
			
		||||
                n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
                group_size=group_size
 | 
			
		||||
            )
 | 
			
		||||
            self.backend_decoders.append(decoder)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -392,6 +412,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
        num_hidden_layers,
 | 
			
		||||
        max_seq_len: int = 128,
 | 
			
		||||
        transpose_value: bool = False,
 | 
			
		||||
        n_splits_linear: int = 1,
 | 
			
		||||
        n_splits_down_proj: int = 1,
 | 
			
		||||
        group_size: int = 0
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.op_parameters = parameters
@@ -422,6 +445,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
             mode="prefill",
             transpose_value=self.transpose_value,
             dtype=np_dtype,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -501,24 +527,53 @@ def run_decode(
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
     num_hidden_layers = model.config.num_hidden_layers
+    group_size = getattr(model.config, "group_size", 0)
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp
 
-        weights = [
-            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        weights = []
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
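The same pattern is repeated in run_prefill below and in convert_minicpm_layer: with a single split per projection, the flat list of (weight, scale) tuples is kept; with multiple splits, the per-split weights and scales are stacked into one tensor per projection so the NPU graph always receives a fixed number of parameters. A shape-only sketch of the two layouts, using dummy tensors rather than real model weights:

```python
import torch

n_splits, out_features, in_features = 4, 1536, 1536   # illustrative sizes
split_in = in_features // n_splits

# Dummy stand-ins for one projection's per-split low-bit weights and scales.
split_weights = [torch.zeros(out_features, split_in, dtype=torch.int8) for _ in range(n_splits)]
split_scales = [torch.zeros(out_features, dtype=torch.float16) for _ in range(n_splits)]

if n_splits == 1:
    # flat layout: one (weight, scale) tuple per projection
    entry = (split_weights[0], split_scales[0])
else:
    # stacked layout: a single (n_splits, ...) weight tensor and scale tensor
    entry = (torch.stack(split_weights, dim=0), torch.stack(split_scales, dim=0))

print(entry[0].shape, entry[1].shape)  # torch.Size([4, 1536, 384]) torch.Size([4, 1536])
```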
@@ -547,6 +602,9 @@ def run_decode(
         max_seq_len=max_seq_len,
         transpose_value=transpose_value_cache,
         do_print=False,
+        n_splits_linear=n_splits_linear,
+        n_splits_down_proj=n_splits_down_proj,
+        group_size=group_size
     )
 
     dist.barrier()
@@ -711,25 +769,55 @@ def run_prefill(
     intermediate_size = model.config.intermediate_size
     scale_depth = model.config.scale_depth
     num_hidden_layers = model.config.num_hidden_layers
+    group_size = getattr(model.config, "group_size", 0)
     deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp
 
-        weights = [
-            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        weights = []
+
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -752,6 +840,9 @@ def run_prefill(
             num_hidden_layers=num_hidden_layers,
             max_seq_len=max_output_len,
             transpose_value=transpose_value_cache,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )
 
         layer_weights.extend(weights)
@@ -273,7 +273,6 @@ class LLMBaseNNFactory(NNFactory):
                                                self.n_splits_linear, wt_dtype=self.dtype,
                                                scale_factor=(self.group_size == 0),
                                                is_prefill=(mode == "prefill"))
-
         return attn_output, new_key_states, new_value_states
 
     def paraformer_layer_norm(self, hidden_states, layernorm_weight, layernorm_bias):
@@ -370,6 +370,9 @@ def convert_llm(model: torch.nn.Module,
 
     if hasattr(model, "lm_head") and isinstance(model.lm_head, SlicedLMHead):
         model.lm_head.get_fused_lm_head()
+    if hasattr(model, "lm_head_1") and isinstance(model.lm_head_1, SlicedLMHead):
+        model.lm_head_1.get_fused_lm_head()
+        model.lm_head_0.get_fused_lm_head()
 
     # patch generate function
     import types
@@ -81,6 +81,7 @@ class MiniCPMLMHead(LLMBaseNNFactory):
         transpose_value: bool = False,
         profile: bool = False,
         device: str = "NPU",
+        n_splits: int = 1,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -108,19 +109,37 @@ class MiniCPMLMHead(LLMBaseNNFactory):
         hidden_states = self.layer_norm(hidden_states, model_norm_weight)
         if vocab_size == 122753:
             # for MiniCPM-2B-sft-bf16
-            hidden_states_1 = self.linear(
-                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
-            )
-            hidden_states_2 = self.linear(
-                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
-            )
+            if n_splits == 1:
+                hidden_states_1 = self.linear(
+                    hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
+                )
+                hidden_states_2 = self.linear(
+                    hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
+                )
+            else:
+                hidden_states_1 = self.dq_split_linear(
+                    hidden_states, 73440, self.hidden_size,
+                    n_splits=n_splits, wt_dtype=dtype, scale_factor=False
+                )
+                hidden_states_2 = self.dq_split_linear(
+                    hidden_states, 73440, self.hidden_size,
+                    n_splits=n_splits, wt_dtype=dtype, scale_factor=False
+                )
+
             hidden_states_2 = self.slice(hidden_states_2, begin=[0, 0, 0], end=[1, 1, 49313])
             hidden_states = self.concat(hidden_states_1, hidden_states_2, axis=2)
         else:
             # for MiniCPM-1B-sft-bf16
-            hidden_states = self.linear(
-                hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
-            )
+            if n_splits == 1:
+                hidden_states = self.linear(
+                    hidden_states, self.vocab_size, self.hidden_size, bias=False,
+                    wt_dtype=self.dtype
+                )
+            else:
+                hidden_states = self.dq_split_linear(
+                    hidden_states, self.vocab_size, self.hidden_size,
+                    n_splits=n_splits, wt_dtype=dtype, scale_factor=False
+                )
 
         # define outputs
         hidden_states = self.convert_to_fp32(hidden_states)
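For MiniCPM-2B, the 122753-entry vocabulary is produced by two projections of width 73440 each, with the second one truncated to its first 49313 columns before concatenation (presumably to keep each NPU linear at a size the hardware handles well). The slice bound follows directly from the vocabulary size:

```python
# MiniCPM-2B LM-head split used above: 73440 + 49313 == vocab_size
vocab_size, first_chunk = 122753, 73440
second_chunk = vocab_size - first_chunk
assert second_chunk == 49313
```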
@@ -145,8 +164,19 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
             # for MiniCPM-1B-sft-bf16
             weights = [(model.lm_head.weight, model.lm_head.scale)]
     else:
-        # TODO
-        pass
+        weights = []
+        if vocab_size == 122753:
+            lm_head_list = [model.lm_head_0.lm_heads, model.lm_head_1.lm_heads]
+        else:
+            lm_head_list = [model.lm_head.lm_heads]
+        for lh in lm_head_list:
+            lm_head_weights = []
+            scales = []
+            for l in lh:
+                lm_head_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(lm_head_weights, axis=0),
+                            torch.stack(scales, axis=0)))
     if isinstance(weights[0], tuple):
         np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
     else:  # FP16 Linear
@@ -162,6 +192,7 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
         dtype=np_dtype,
         model_norm_weight=model_norm.weight.to(torch.float16),
         vocab_size=vocab_size,
+        n_splits=n_splits_linear
     )
     last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
 
			
		||||
| 
						 | 
				
			
			@ -175,8 +206,9 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
 | 
			
		|||
        else:
 | 
			
		||||
            weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(), ]
 | 
			
		||||
    else:
 | 
			
		||||
        # TODO
 | 
			
		||||
        pass
 | 
			
		||||
        weight_numpy = [v.numpy() for v in weights[0]]
 | 
			
		||||
        if vocab_size == 122753:
 | 
			
		||||
            weight_numpy.extend([v.numpy() for v in weights[1]])
 | 
			
		||||
 | 
			
		||||
    for idx, weight in enumerate(weight_numpy):
 | 
			
		||||
        bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
@@ -214,18 +246,40 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
 
+    weights = []
     if n_splits_linear == 1:
-        weights = [
-            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                    attn_layer.k_proj_dq_list,
+                                    attn_layer.v_proj_dq_list,
+                                    attn_layer.o_proj_dq_list,
+                                    mlp_layer.gate_proj_dq_list,
+                                    mlp_layer.up_proj_dq_list):
+            weights.append((q.weight, q.scale))
+            weights.append((k.weight, k.scale))
+            weights.append((v.weight, v.scale))
+            weights.append((o.weight, o.scale))
+            weights.append((g.weight, g.scale))
+            weights.append((u.weight, u.scale))
     else:
-        # TODO
-        pass
+        for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                           attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+            l_weights = []
+            scales = []
+            for l in layer_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0),
+                            torch.stack(scales, axis=0)))
 
+    if n_splits_down_proj == 1:
+        for l in mlp_layer.down_proj_dq_list:
+            weights.append((l.weight, l.scale))
+    else:
+        l_weights = []
+        scales = []
+        for l in mlp_layer.down_proj_dq_list:
+            l_weights.append(l.weight)
+            scales.append(l.scale)
+        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
     cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
     cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -254,6 +308,9 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
         mode="decode",
         transpose_value=transpose_value_cache,
         dtype=np_dtype,
+        n_splits_linear=n_splits_linear,
+        n_splits_down_proj=n_splits_down_proj,
+        group_size=group_size
     )
     rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
                                                         f"decoder_layer_{layer_idx}",
@@ -19,6 +19,7 @@ import torch
 import numpy as np
 import os
 from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead
+from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead
 
 
 def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
@@ -27,18 +28,16 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
     rms_norm_eps = model.config.rms_norm_eps
     vocab_size = model.config.vocab_size
     model_norm = model.model.norm
-    if model.config.intermediate_size == 18944:
-        lm_heads = model.lm_head.lm_heads  # Qwen2-7B is always SlicedLMHead
-    else:
-        lm_heads = [model.lm_head]
-    if n_splits_linear == 1:
-        weights = [(lm_heads[0].weight, lm_heads[0].scale)]
+    lm_head = model.lm_head
+    if not isinstance(lm_head, SlicedLMHead):
+        weights = [(lm_head.weight, lm_head.scale)]
     else:
+        lm_heads = lm_head.lm_heads
         lm_head_weights = []
         scales = []
-        for i in range(n_splits_linear):
-            lm_head_weights.append(lm_heads[i].weight)
-            scales.append(lm_heads[i].scale)
+        for l in lm_heads:
+            lm_head_weights.append(l.weight)
+            scales.append(l.scale)
         weights = [(torch.stack(lm_head_weights, axis=0),
                     torch.stack(scales, axis=0))]
     if isinstance(weights[0], tuple):
@@ -61,9 +60,9 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
     last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
 
     # save weights bins files
-    if n_splits_linear == 1:
+    if not isinstance(lm_head, SlicedLMHead):
         weight_numpy = [
-            lm_heads[0].weight.data.numpy(), lm_heads[0].scale.data.numpy(),
+            lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
         ]
     else:
         weight_numpy = [v.numpy() for v in weights[0]]