[NPU] Support Baichuan groupwise & gw code refactor (#12337)
* support minicpm 1b & qwen 1.5b gw
* support minicpm 1b
* baichuan part
* update
* support minicpm 1b & qwen 1.5b gw
* support minicpm 1b
* baichuan part
* update
* update
* update
* baichuan support
* code refactor
* remove code
* fix style
* address comments
* revert
parent 812d5cc32e
commit b2e69a896c

13 changed files with 367 additions and 434 deletions
@@ -60,6 +60,7 @@ if __name__ == "__main__":
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-context-len", type=int, default=1024)
     parser.add_argument("--max-prompt-len", type=int, default=512)
+    parser.add_argument("--quantization_group_size", type=int, default=0)
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
     parser.add_argument("--disable-streaming", action="store_true", default=False)

@@ -72,6 +73,7 @@ if __name__ == "__main__":
                                                      pipeline=True,
                                                      max_context_len=args.max_context_len,
                                                      max_prompt_len=args.max_prompt_len,
+                                                     quantization_group_size=args.quantization_group_size,
                                                      torch_dtype=torch.float16,
                                                      attn_implementation="eager",
                                                      transpose_value_cache=not args.disable_transpose_value_cache,
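The two hunks above add a `--quantization_group_size` option to the example script and thread it into `from_pretrained`. Purely as a hedged illustration (the loader import path and the model id below are assumptions, not part of this diff), the resulting call looks roughly like this:

```python
# Sketch only: the import path and model id are assumptions; the keyword
# arguments mirror the hunk above (0 keeps channel-wise quantization,
# a positive value enables groupwise quantization).
import torch
from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed loader

model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-7B-Chat",      # placeholder model id
    trust_remote_code=True,
    pipeline=True,
    max_context_len=1024,
    max_prompt_len=512,
    quantization_group_size=64,            # new knob added by this commit
    torch_dtype=torch.float16,
    attn_implementation="eager",
    transpose_value_cache=True,
)
```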
				
			
@@ -50,6 +50,9 @@ from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.transformers.npu_models.mp_models_base import run_model
 from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
+from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from torch.nn import CrossEntropyLoss


 class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
@@ -75,12 +78,18 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         device: str = "NPU",
         rms_norm_eps,
         intermediate_size,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
                          dtype=dtype,
                          profile=profile,
-                         device=device)
+                         device=device,
+                         n_splits_linear=n_splits_linear,
+                         n_splits_down_proj=n_splits_down_proj,
+                         group_size=group_size)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -115,8 +124,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
                                                   dtype=np.int64)
         else:
-            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
-                                                  dtype=np.int64)
+            attention_mask = None

         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
         # self.num_key_value_heads = num_key_value_heads
@@ -178,6 +186,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
                 post_attention_layernorm_weight=post_attn_layernorm_weights[i],
                 past_key=past_keys[i],
                 past_value=past_values[i],
+                use_prefill_sdp=True,
             )
             curr_key_values.append((new_key_states, new_value_states))

@@ -189,6 +198,9 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
             new_value_states = self.convert_to_fp16(curr_key_values[i][1])

         print("start compiling")
+        if mode == "prefill" and os.environ.get("IPEX_LLM_NPU_DISABLE_COMPILE_OPT", "0") != "1":
+            self.compile(npu_dpu_groups=6)
+        else:
             self.compile()

     def attention(self,
@@ -206,15 +218,23 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
                   seq_len,
                   q_bias=None,
                   k_bias=None,
-                  v_bias=None):
+                  v_bias=None,
+                  use_prefill_sdp=False):
         hidden_size = num_heads * head_dim
+        if self.n_splits_linear != 1:
+            hidden_states = self.unsqueeze(hidden_states, axis=0)
+
         proj = self.linear(
             hidden_states,
             3 * hidden_size,
             hidden_size,
             bias=False,
-            wt_dtype=self.dtype
+            wt_dtype=self.dtype,
+            n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
         )

         proj = self.reshape(proj, [-1, 3, hidden_size])  # b*s, 3, h
         proj = self.unsqueeze(proj, [0])  # b, s, 3, h
         proj = self.transpose(proj, [2, 1, 0, 3])  # 3, s, b, h
@@ -224,8 +244,14 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         key_states = self.reshape(proj[1, ...], [1, self.seq_len, num_heads, head_dim])
         key_states = self.transpose(key_states, [0, 2, 1, 3])
         value_states = self.reshape(proj[2, ...], [1, self.seq_len, num_heads, head_dim])
+
+        use_ov_sdp = (mode == "prefill") and use_prefill_sdp
         if self.transpose_value:
-            value_states = self.transpose(value_states, [0, 2, 3, 1])
+            new_value_states = self.transpose(value_states, [0, 2, 3, 1])
+            if use_ov_sdp:
+                value_states = self.transpose(value_states, [0, 2, 1, 3])
+            else:
+                value_states = new_value_states
         else:
             value_states = self.transpose(value_states, [0, 2, 1, 3])

@@ -243,7 +269,6 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
             head_dim=head_dim,
         )
         new_key_states = key_states
-        new_value_states = value_states

         if self.mode == "decode":
             key_states = self.concat(past_key, key_states, axis=-2)
@@ -252,6 +277,14 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
             else:
                 value_states = self.concat(past_value, value_states, axis=-2)

+        if use_ov_sdp:
+            value_states = self.convert_to_fp32(value_states)
+            key_states = self.convert_to_fp32(key_states)
+            query_states = self.convert_to_fp32(query_states)
+            attn_output = self.scaled_dot_product_attention(
+                query_states, key_states, value_states, None, True)
+            attn_output = self.convert_to_fp16(attn_output)
+        else:
             attn_weight = self.matmul(query_states, key_states, False, True) / (
                 math.sqrt(self.head_dim))
             attention_mask = self.convert_to_fp16(attention_mask)
@@ -265,7 +298,10 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])

         attn_output = self.linear(
-            attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
+            attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype,
+            n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
         )
         return attn_output, new_key_states, new_value_states

@@ -278,6 +314,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         post_attention_layernorm_weight,
         past_key=None,
         past_value=None,
+        use_prefill_sdp=False,
     ):

         residual = hidden_states
@@ -298,12 +335,13 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
             num_heads=self.num_heads,
             head_dim=self.head_dim,
             seq_len=self.seq_len,
+            use_prefill_sdp=use_prefill_sdp,
         )

         hidden_states = self.eltwise_add(residual, attn_output)
         residual = hidden_states
         hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, self.seq_len, self.mode)
         hidden_states = self.eltwise_add(residual, hidden_states)
         hidden_states = self.convert_to_fp16(hidden_states)

@@ -329,6 +367,9 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
         max_seq_len: int = 1024,
         transpose_value: bool = False,
         do_print: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__()

@@ -338,6 +379,10 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
         for w in parameters:
             if isinstance(w, tuple):  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif w.dtype in [torch.int8, torch.uint8]:    # QuantizedLinear weight
+                op_parameters.append(w.numpy())
+            elif isinstance(w, np.ndarray):     # scale
+                op_parameters.append(w)
             else:
                 op_parameters.append(w.to(torch.float16).numpy())
         self.op_parameters = op_parameters
@@ -346,6 +391,10 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
         self.transpose_value = transpose_value
         if isinstance(parameters[0], tuple):
             np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
+        elif parameters[0].dtype == torch.int8:
+            np_dtype = np.int8
+        elif parameters[0].dtype == torch.uint8:
+            np_dtype = np.uint8
         else:  # FP16 Linear
             np_dtype = np.float16

@@ -380,6 +429,9 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
                 mode="decode",
                 transpose_value=self.transpose_value,
                 dtype=np_dtype,
+                n_splits_linear=n_splits_linear,
+                n_splits_down_proj=n_splits_down_proj,
+                group_size=group_size
             )
             self.backend_decoders.append(decoder)

@@ -453,6 +505,9 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
         intermediate_size,
         max_seq_len: int = 128,
         transpose_value: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -481,6 +536,9 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
             mode="prefill",
             transpose_value=self.transpose_value,
             dtype=np_dtype,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -507,7 +565,6 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):

         backend_cls = self.backend_cls_prefill
         inputs = (hidden_states.to(torch.float16),
-                  attention_mask.to(torch.int64),
                   position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
@@ -557,22 +614,28 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
+    group_size = getattr(model.config, "group_size", 0)
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp

-        weights = [
-            (attn_layer.W_pack.weight, attn_layer.W_pack.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        weights = []
+        for layer_list in [attn_layer.W_pack_dq_list, attn_layer.o_proj_dq_list,
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
+            l_weights = []
+            scales = []
+            for l in layer_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
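The refactored loop above replaces per-projection `(weight, scale)` tuples with one stacked pair per projection, built from the `*_dq_list` sub-linears that `split_linears` registers. A standalone sketch of the stacking layout, where `DummyLinear` and the sizes are made up for illustration and only the `torch.stack` pattern mirrors the diff:

```python
# Illustration only: DummyLinear and the shapes below are hypothetical.
import torch

class DummyLinear:
    def __init__(self, out_features, in_features):
        self.weight = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
        self.scale = torch.rand(out_features, dtype=torch.float16)

n_splits = 2
# pretend gate_proj was split into n_splits sub-linears along its input dimension
gate_proj_dq_list = [DummyLinear(11008, 4096 // n_splits) for _ in range(n_splits)]

l_weights = [l.weight for l in gate_proj_dq_list]
scales = [l.scale for l in gate_proj_dq_list]
stacked = (torch.stack(l_weights, dim=0), torch.stack(scales, dim=0))

print(stacked[0].shape)  # torch.Size([2, 11008, 2048]) -- one slice per split
print(stacked[1].shape)  # torch.Size([2, 11008])
```

Stacking keeps a single weights entry per projection regardless of the split count, so the decoder constructors only need the extra `n_splits_linear` / `n_splits_down_proj` / `group_size` arguments added in this commit.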
				
			
@@ -599,6 +662,9 @@ def run_decode(
         max_seq_len=max_seq_len,
         transpose_value=transpose_value_cache,
         do_print=False,
+        n_splits_linear=n_splits_linear,
+        n_splits_down_proj=n_splits_down_proj,
+        group_size=group_size
     )

     dist.barrier()
@@ -754,23 +820,29 @@ def run_prefill(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
+    group_size = getattr(model.config, "group_size", 0)
     deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp

-        weights = [
-            (attn_layer.W_pack.weight, attn_layer.W_pack.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        weights = []
+        for layer_list in [attn_layer.W_pack_dq_list, attn_layer.o_proj_dq_list,
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
+            l_weights = []
+            scales = []
+            for l in layer_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -791,6 +863,9 @@ def run_prefill(
             intermediate_size=intermediate_size,
             max_seq_len=max_output_len,
             transpose_value=transpose_value_cache,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )

         layer_weights.extend(weights)
@@ -1025,3 +1100,71 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
         )

     return baichuan_fused_model_forward
+
+
+def baichuan2_causal_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+    # ipex-llm change start
+    hidden_states = reshape_lm_head_input(hidden_states)
+    # ipex-llm change end
+    logits = self.lm_head(hidden_states)
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        softmax_normalizer = shift_logits.max(-1).values ** 2
+        z_loss = self.config.z_loss_weight * softmax_normalizer.mean()
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels) + z_loss
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
@@ -75,10 +75,11 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down
     from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention
     from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
     attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
+    baichuan_attn_module_names = ["W_pack", "o_proj"]
     mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
     if (
         isinstance(module, (Qwen2Attention, LlamaAttention))
-        or module.__class__.__name__ in ['MiniCPMAttention', 'Attention']
+        or module.__class__.__name__ in ['MiniCPMAttention']
     ):
         for name in attn_module_names:
             setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
@@ -97,3 +98,10 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down
                                                             n_splits=n_splits_mlp,
                                                             load=load))
             delattr(module, name)
+    elif module.__class__.__name__ == 'Attention' and module.config.model_type == 'baichuan':
+        # baichuan attention
+        for name in baichuan_attn_module_names:
+            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
+                                                            n_splits=n_splits_hidden_size,
+                                                            load=load))
+            delattr(module, name)
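The new branch registers `W_pack_dq_list` and `o_proj_dq_list` on Baichuan's fused attention module and deletes the original linears. `split_linear` itself is not shown in this diff; purely as a hedged illustration of the idea, splitting a linear along its input dimension could look like the sketch below (the function name and behavior here are assumptions, not ipex-llm code):

```python
# Hypothetical sketch, not the ipex-llm split_linear implementation.
import torch
import torch.nn as nn

def naive_split_linear(layer: nn.Linear, n_splits: int) -> nn.ModuleList:
    # split the (out, in) weight along the input dimension into equal chunks
    assert layer.in_features % n_splits == 0
    chunk = layer.in_features // n_splits
    subs = []
    for w in torch.split(layer.weight.data, chunk, dim=1):
        sub = nn.Linear(chunk, layer.out_features, bias=False)
        sub.weight.data.copy_(w)
        subs.append(sub)
    # summing each sub-linear's output over the matching input slice
    # reproduces the original projection (bias handling omitted)
    return nn.ModuleList(subs)
```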
				
			
@@ -87,7 +87,7 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
             model.llm.config.model_type = "llama"
         model = model.llm

-    if model.config.model_type in ["qwen2", "llama", "minicpm"]:
+    if model.config.model_type in ["qwen2", "llama", "minicpm", "baichuan"]:
         from ipex_llm.transformers.npu_models.common import split_linears
         if quantization_group_size == 0:
             n_splits_linear = 1
@@ -245,6 +245,8 @@ def convert_baichuan(
     modeling_module_name = model.__class__.__module__
     module = importlib.import_module(modeling_module_name)
     convert_forward(model, module.BaichuanModel, baichuan_model_forward)
+    from ipex_llm.transformers.npu_models.baichuan_mp import baichuan2_causal_forward
+    convert_forward(model, module.BaichuanForCausalLM, baichuan2_causal_forward)


 def convert_minicpm(
@@ -392,7 +394,7 @@ def optimize_llm(
         if intra_pp is None:
             intra_pp = 2
         if inter_pp is None:
-            inter_pp = 2
+            inter_pp = 2 if group_size == 0 else 4
         convert_baichuan(model,
                          max_output_len=max_context_len,
                          max_prompt_len=max_prompt_len,
@@ -560,23 +560,10 @@ def run_decode(
         mlp_layer = curr_layer.mlp

         weights = []
-        if n_splits_linear == 1:
-            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                        attn_layer.k_proj_dq_list,
-                                        attn_layer.v_proj_dq_list,
-                                        attn_layer.o_proj_dq_list,
-                                        mlp_layer.gate_proj_dq_list,
-                                        mlp_layer.up_proj_dq_list):
-                weights.append((q.weight, q.scale))
-                weights.append((k.weight, k.scale))
-                weights.append((v.weight, v.scale))
-                weights.append((o.weight, o.scale))
-                weights.append((g.weight, g.scale))
-                weights.append((u.weight, u.scale))
-        else:
         for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                            attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
             for l in layer_list:
@@ -584,17 +571,6 @@ def run_decode(
                 scales.append(l.scale)
             weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-        if n_splits_down_proj == 1:
-            for l in mlp_layer.down_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            l_weights = []
-            scales = []
-            for l in mlp_layer.down_proj_dq_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
         if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
             cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
             cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -844,40 +820,15 @@ def run_prefill(

                 weights = []

-                if n_splits_linear == 1:
-                    for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                                attn_layer.k_proj_dq_list,
-                                                attn_layer.v_proj_dq_list,
-                                                attn_layer.o_proj_dq_list,
-                                                mlp_layer.gate_proj_dq_list,
-                                                mlp_layer.up_proj_dq_list):
-                        weights.append((q.weight, q.scale))
-                        weights.append((k.weight, k.scale))
-                        weights.append((v.weight, v.scale))
-                        weights.append((o.weight, o.scale))
-                        weights.append((g.weight, g.scale))
-                        weights.append((u.weight, u.scale))
-                else:
                 for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                                    attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                                   mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                                   mlp_layer.down_proj_dq_list]:
                     l_weights = []
                     scales = []
                     for l in layer_list:
                         l_weights.append(l.weight)
                         scales.append(l.scale)
                         weights.append((torch.stack(l_weights, axis=0),
                                         torch.stack(scales, axis=0)))

-                if n_splits_down_proj == 1:
-                    for l in mlp_layer.down_proj_dq_list:
-                        weights.append((l.weight, l.scale))
-                else:
-                    l_weights = []
-                    scales = []
-                    for l in mlp_layer.down_proj_dq_list:
-                        l_weights.append(l.weight)
-                        scales.append(l.scale)
-                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
                 if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
@@ -540,23 +540,10 @@ def run_decode(
         mlp_layer = curr_layer.mlp

         weights = []
-        if n_splits_linear == 1:
-            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                        attn_layer.k_proj_dq_list,
-                                        attn_layer.v_proj_dq_list,
-                                        attn_layer.o_proj_dq_list,
-                                        mlp_layer.gate_proj_dq_list,
-                                        mlp_layer.up_proj_dq_list):
-                weights.append((q.weight, q.scale))
-                weights.append((k.weight, k.scale))
-                weights.append((v.weight, v.scale))
-                weights.append((o.weight, o.scale))
-                weights.append((g.weight, g.scale))
-                weights.append((u.weight, u.scale))
-        else:
         for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                            attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
             for l in layer_list:
@@ -564,17 +551,6 @@ def run_decode(
                 scales.append(l.scale)
             weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-        if n_splits_down_proj == 1:
-            for l in mlp_layer.down_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            l_weights = []
-            scales = []
-            for l in mlp_layer.down_proj_dq_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
         layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
@@ -783,24 +759,10 @@ def run_prefill(
         mlp_layer = curr_layer.mlp

         weights = []

-        if n_splits_linear == 1:
-            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                        attn_layer.k_proj_dq_list,
-                                        attn_layer.v_proj_dq_list,
-                                        attn_layer.o_proj_dq_list,
-                                        mlp_layer.gate_proj_dq_list,
-                                        mlp_layer.up_proj_dq_list):
-                weights.append((q.weight, q.scale))
-                weights.append((k.weight, k.scale))
-                weights.append((v.weight, v.scale))
-                weights.append((o.weight, o.scale))
-                weights.append((g.weight, g.scale))
-                weights.append((u.weight, u.scale))
-        else:
         for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                            attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
             for l in layer_list:
@@ -808,17 +770,6 @@ def run_prefill(
                 scales.append(l.scale)
             weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-        if n_splits_down_proj == 1:
-            for l in mlp_layer.down_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            l_weights = []
-            scales = []
-            for l in mlp_layer.down_proj_dq_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)

@@ -138,13 +138,18 @@ class LLMBaseNNFactory(NNFactory):
                   use_prefill_sdp=False):
         hidden_size = num_heads * head_dim
         num_key_value_groups = num_heads // num_key_value_heads
-        if self.n_splits_linear == 1:
+        if self.n_splits_linear != 1:
+            hidden_states = self.unsqueeze(hidden_states, axis=0)
+
         query_states = self.linear(
             hidden_states,
             num_heads * head_dim,
             hidden_size,
             bias=False,
             wt_dtype=self.dtype,
+            n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
         )

         key_states = self.linear(
@@ -153,6 +158,9 @@ class LLMBaseNNFactory(NNFactory):
             hidden_size,
             bias=False,
             wt_dtype=self.dtype,
+            n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
         )

         value_states = self.linear(
@@ -161,24 +169,10 @@ class LLMBaseNNFactory(NNFactory):
             hidden_size,
             bias=False,
             wt_dtype=self.dtype,
+            n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
         )
-        else:
-            hidden_states = self.unsqueeze(hidden_states, axis=0)
-            query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
-                                                hidden_size, self.n_splits_linear,
-                                                wt_dtype=self.dtype,
-                                                scale_factor=(self.group_size == 0),
-                                                is_prefill=(mode == "prefill"))
-            key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                              hidden_size, self.n_splits_linear,
-                                              wt_dtype=self.dtype,
-                                              scale_factor=(self.group_size == 0),
-                                              is_prefill=(mode == "prefill"))
-            value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                hidden_size, self.n_splits_linear,
-                                                wt_dtype=self.dtype,
-                                                scale_factor=(self.group_size == 0),
-                                                is_prefill=(mode == "prefill"))

         if q_bias is not None:
             query_states = query_states + q_bias
@@ -263,15 +257,12 @@ class LLMBaseNNFactory(NNFactory):
         attn_output = self.transpose(attn_output, [0, 2, 1, 3])
         attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])

-        if self.n_splits_linear == 1:
         attn_output = self.linear(
-                attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
-            )
-        else:
-            attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
-                                               self.n_splits_linear, wt_dtype=self.dtype,
+            attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype,
+            n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-                                               is_prefill=(mode == "prefill"))
+            is_prefill=(mode == "prefill")
+        )
         return attn_output, new_key_states, new_value_states

     def paraformer_layer_norm(self, hidden_states, layernorm_weight, layernorm_bias):
@@ -434,38 +425,26 @@ class LLMBaseNNFactory(NNFactory):
         return w_2

     def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
-        if self.n_splits_linear == 1:
         mm1 = self.linear(
             hidden_states, self.intermediate_size, self.hidden_size, bias=False,
-                wt_dtype=self.dtype
+            wt_dtype=self.dtype, n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
         )
         mm2 = self.linear(
             hidden_states, self.intermediate_size, self.hidden_size, bias=False,
-                wt_dtype=self.dtype
+            wt_dtype=self.dtype, n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
         )  # type: ignore[attr-defined]
         mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
-        else:
-            invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                       self.n_splits_linear, wt_dtype=self.dtype,
-                                       scale_factor=(self.group_size == 0),
-                                       is_prefill=(mode == "prefill"))
-            mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                       self.n_splits_linear, wt_dtype=self.dtype,
-                                       scale_factor=(self.group_size == 0),
-                                       is_prefill=(mode == "prefill"))
-            mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]

-        if self.n_splits_down_proj == 1:
         hidden_states = self.linear(
-                mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
-            )
-        else:
-            invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
-                                                 self.n_splits_down_proj, wt_dtype=self.dtype,
+            mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype,
+            n_splits=self.n_splits_down_proj,
             scale_factor=(self.group_size == 0),
-                                                 is_prefill=(mode == "prefill"))
+            is_prefill=(mode == "prefill")
+        )
         return hidden_states

     def layer_norm(self, hidden_states, layernorm_weight):
@@ -571,8 +550,26 @@ class LLMBaseNNFactory(NNFactory):
         self.input_ops.append(op)
         return op

-    def linear(self, *args, **kwargs):
-        op = super().linear(*args, **kwargs)
+    def linear(self,
+               input_node: ctypes._Pointer,
+               output_channels: int,
+               input_channels: int,
+               bias: Optional[bool] = False,
+               act_dtype: npt.DTypeLike = np.float16,
+               wt_dtype: npt.DTypeLike = np.float16,
+               n_splits: int = 1,
+               scale_factor: bool = True,
+               is_prefill: bool = False):
+        if n_splits == 1:
+            op = super().linear(input_node, output_channels,
+                                input_channels, bias, act_dtype,
+                                wt_dtype, scale_factor=scale_factor)
+        else:
+            op = super().dq_split_linear(input_node, n_splits,
+                                         output_channels, input_channels,
+                                         bias=bias, act_dtype=act_dtype,
+                                         wt_dtype=wt_dtype, scale_factor=scale_factor,
+                                         is_prefill=is_prefill)
         self.linear_ops.append(op)
         return op

			
@@ -586,23 +586,10 @@ def run_decode(
         mlp_layer = curr_layer.mlp

         weights = []
-        if n_splits_linear == 1:
-            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                        attn_layer.k_proj_dq_list,
-                                        attn_layer.v_proj_dq_list,
-                                        attn_layer.o_proj_dq_list,
-                                        mlp_layer.gate_proj_dq_list,
-                                        mlp_layer.up_proj_dq_list):
-                weights.append((q.weight, q.scale))
-                weights.append((k.weight, k.scale))
-                weights.append((v.weight, v.scale))
-                weights.append((o.weight, o.scale))
-                weights.append((g.weight, g.scale))
-                weights.append((u.weight, u.scale))
-        else:
         for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                            attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
             for l in layer_list:
@@ -610,17 +597,6 @@ def run_decode(
                 scales.append(l.scale)
             weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-        if n_splits_down_proj == 1:
-            for l in mlp_layer.down_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            l_weights = []
-            scales = []
-            for l in mlp_layer.down_proj_dq_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
         layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
@@ -839,23 +815,10 @@ def run_prefill(
         mlp_layer = curr_layer.mlp

         weights = []
-        if n_splits_linear == 1:
-            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                        attn_layer.k_proj_dq_list,
-                                        attn_layer.v_proj_dq_list,
-                                        attn_layer.o_proj_dq_list,
-                                        mlp_layer.gate_proj_dq_list,
-                                        mlp_layer.up_proj_dq_list):
-                weights.append((q.weight, q.scale))
-                weights.append((k.weight, k.scale))
-                weights.append((v.weight, v.scale))
-                weights.append((o.weight, o.scale))
-                weights.append((g.weight, g.scale))
-                weights.append((u.weight, u.scale))
-        else:
         for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                            attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
             for l in layer_list:
@@ -863,17 +826,6 @@ def run_prefill(
                 scales.append(l.scale)
             weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-        if n_splits_down_proj == 1:
-            for l in mlp_layer.down_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            l_weights = []
-            scales = []
-            for l in mlp_layer.down_proj_dq_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)

@@ -28,7 +28,17 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
    vocab_size = model.config.vocab_size
    model_norm = model.model.norm
    lm_head = model.lm_head
    if n_splits_linear == 1:
        weights = [(lm_head.weight, lm_head.scale)]
    else:
        lm_heads = lm_head.lm_heads
        lm_head_weights = []
        scales = []
        for l in lm_heads:
            lm_head_weights.append(l.weight)
            scales.append(l.scale)
        weights = [(torch.stack(lm_head_weights, axis=0),
                    torch.stack(scales, axis=0))]
    if isinstance(weights[0], tuple):
        np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
    else:  # FP16 Linear
@@ -44,13 +54,17 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
        dtype=np_dtype,
        model_norm_weight=model_norm.weight.to(torch.float16),
        vocab_size=vocab_size,
        n_splits=n_splits_linear
    )
    last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)

    # save weights bins files
    if n_splits_linear == 1:
        weight_numpy = [
            lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
        ]
    else:
        weight_numpy = [v.numpy() for v in weights[0]]

    for idx, weight in enumerate(weight_numpy):
        bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
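As a rough illustration of the bin dump above (the directory and array contents here are hypothetical; the real file names come from the converter's `model_lm_head_input_{1+idx}` convention), each entry in `weight_numpy` is written out as a raw binary blob that the runtime later maps as an extra lm-head input:

```python
import os
import numpy as np

def save_weight_bins(weight_numpy, weight_dir, prefix="model_lm_head_input"):
    # Write each array as a raw blob; indexing starts at 1 to mirror the
    # f"model_lm_head_input_{1+idx}.bin" naming used by the converter.
    os.makedirs(weight_dir, exist_ok=True)
    for idx, weight in enumerate(weight_numpy):
        bin_file = os.path.join(weight_dir, f"{prefix}_{1 + idx}.bin")
        weight.tofile(bin_file)

# Hypothetical example: an int8 weight followed by its float16 scales.
save_weight_bins([np.zeros((8, 8), dtype=np.int8),
                  np.ones((8, 1), dtype=np.float16)],
                 weight_dir="./weights_demo")
```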
@@ -83,17 +97,15 @@ def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
    mlp_layer = curr_layer.mlp

    weights = []
    if n_splits_linear == 1:
        weights = [
            (attn_layer.W_pack.weight, attn_layer.W_pack.scale),
            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
        ]
    else:
        # TODO
        pass
    for layer_list in [attn_layer.W_pack_dq_list, attn_layer.o_proj_dq_list,
                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
                       mlp_layer.down_proj_dq_list]:
        l_weights = []
        scales = []
        for l in layer_list:
            l_weights.append(l.weight)
            scales.append(l.scale)
        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

    cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
    cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -119,6 +131,9 @@ def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
        mode="decode",
        transpose_value=transpose_value_cache,
        dtype=np_dtype,
        n_splits_linear=n_splits_linear,
        n_splits_down_proj=n_splits_down_proj,
        group_size=group_size
    )
    rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
                                                        f"decoder_layer_{layer_idx}",
@@ -91,21 +91,21 @@ class LowBitLLMLMHead(LLMBaseNNFactory):
        self.head_dim = self.hidden_size // self.num_heads

        # define input, the order self.parameter matters
        if n_splits == 1:
            input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
        else:
            input = self.create_input_op((1, self.batch_size, self.hidden_size))

        hidden_states = input

        # model norm and lm head
        model_norm_weight = self.constant(model_norm_weight)
        hidden_states = self.layer_norm(hidden_states, model_norm_weight)
        if n_splits == 1:
        hidden_states = self.linear(
                hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
            )
        else:
            hidden_states = self.dq_split_linear(
                hidden_states, self.vocab_size, self.hidden_size, n_splits,
                wt_dtype=dtype, scale_factor=False
            hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype,
            n_splits=n_splits,
            scale_factor=(n_splits == 1),
        )

        # define outputs
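The refactored lm-head now always goes through a single `self.linear(..., n_splits=..., scale_factor=...)` call instead of branching into `dq_split_linear`. As a plain-torch sketch of what a split linear computes (this is not the NPU graph op, and the shapes are arbitrary), the input feature dimension is chunked into `n_splits` pieces, each piece is multiplied by its own weight slice, and the partial products are summed:

```python
import torch

def split_linear(x: torch.Tensor, weight: torch.Tensor, n_splits: int = 1):
    # Reference computation only: chunk the input features and the weight's
    # input dimension into n_splits pieces and accumulate the partial matmuls.
    if n_splits == 1:
        return x @ weight.t()
    x_chunks = x.chunk(n_splits, dim=-1)
    w_chunks = weight.chunk(n_splits, dim=-1)
    out = x_chunks[0] @ w_chunks[0].t()
    for xc, wc in zip(x_chunks[1:], w_chunks[1:]):
        out = out + xc @ wc.t()
    return out

x = torch.randn(1, 1, 1024)
w = torch.randn(4096, 1024)
assert torch.allclose(split_linear(x, w, n_splits=4), x @ w.t(), atol=1e-3)
```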
@@ -174,40 +174,15 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
    mlp_layer = curr_layer.mlp

    weights = []
    if n_splits_linear == 1:
        for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
                                    attn_layer.k_proj_dq_list,
                                    attn_layer.v_proj_dq_list,
                                    attn_layer.o_proj_dq_list,
                                    mlp_layer.gate_proj_dq_list,
                                    mlp_layer.up_proj_dq_list):
            weights.append((q.weight, q.scale))
            weights.append((k.weight, k.scale))
            weights.append((v.weight, v.scale))
            weights.append((o.weight, o.scale))
            weights.append((g.weight, g.scale))
            weights.append((u.weight, u.scale))
    else:
    for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                       attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
                       mlp_layer.down_proj_dq_list]:
        l_weights = []
        scales = []
        for l in layer_list:
            l_weights.append(l.weight)
            scales.append(l.scale)
            weights.append((torch.stack(l_weights, axis=0),
                            torch.stack(scales, axis=0)))

    if n_splits_down_proj == 1:
        for l in mlp_layer.down_proj_dq_list:
            weights.append((l.weight, l.scale))
    else:
        l_weights = []
        scales = []
        for l in mlp_layer.down_proj_dq_list:
            l_weights.append(l.weight)
            scales.append(l.scale)
        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

    if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
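A quick sanity sketch of why dropping the per-model `n_splits_linear == 1` branches is safe (shapes here are made up for illustration): with a single split, the unified stacking loop returns the same data as the old per-tensor path, just with a leading axis of size 1.

```python
import torch

w = torch.randint(-128, 127, (64, 64), dtype=torch.int8)
s = torch.rand(64, 1, dtype=torch.float16)

old_style = (w, s)                                                # old n_splits_linear == 1 branch
new_style = (torch.stack([w], axis=0), torch.stack([s], axis=0))  # unified loop with one split

assert torch.equal(new_style[0][0], old_style[0])
assert torch.equal(new_style[1][0], old_style[1])
print(new_style[0].shape)  # torch.Size([1, 64, 64])
```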
@@ -109,36 +109,22 @@ class MiniCPMLMHead(LLMBaseNNFactory):
        hidden_states = self.layer_norm(hidden_states, model_norm_weight)
        if vocab_size == 122753:
            # for MiniCPM-2B-sft-bf16
            if n_splits == 1:
            hidden_states_1 = self.linear(
                    hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype,
                n_splits=n_splits, scale_factor=(n_splits == 1)
            )
            hidden_states_2 = self.linear(
                    hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
                )
            else:
                hidden_states_1 = self.dq_split_linear(
                    hidden_states, 73440, self.hidden_size,
                    n_splits=n_splits, wt_dtype=dtype, scale_factor=False
                )
                hidden_states_2 = self.dq_split_linear(
                    hidden_states, 73440, self.hidden_size,
                    n_splits=n_splits, wt_dtype=dtype, scale_factor=False
                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype,
                n_splits=n_splits, scale_factor=(n_splits == 1)
            )

            hidden_states_2 = self.slice(hidden_states_2, begin=[0, 0, 0], end=[1, 1, 49313])
            hidden_states = self.concat(hidden_states_1, hidden_states_2, axis=2)
        else:
            # for MiniCPM-1B-sft-bf16
            if n_splits == 1:
            hidden_states = self.linear(
                hidden_states, self.vocab_size, self.hidden_size, bias=False,
                    wt_dtype=self.dtype
                )
            else:
                hidden_states = self.dq_split_linear(
                    hidden_states, self.vocab_size, self.hidden_size,
                    n_splits=n_splits, wt_dtype=dtype, scale_factor=False
                wt_dtype=self.dtype, n_splits=n_splits, scale_factor=(n_splits == 1)
            )

        # define outputs
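For the 122753-token MiniCPM-2B vocabulary, the head is built from two 73440-wide projections, with the second output sliced down to 49313 columns before concatenation (73440 + 49313 = 122753). Below is a plain-torch sketch of that arithmetic, with made-up float weights standing in for the low-bit NPU tensors:

```python
import torch

def minicpm_2b_vocab_head(hidden, w1, w2):
    # Two projections to 73440 logits each; only the first 49313 columns of the
    # second are kept, so the concatenated result covers all 122753 tokens.
    logits_1 = hidden @ w1.t()
    logits_2 = (hidden @ w2.t())[..., :49313]
    return torch.cat([logits_1, logits_2], dim=-1)

hidden = torch.randn(1, 1, 64)
w1 = torch.randn(73440, 64)
w2 = torch.randn(73440, 64)
print(minicpm_2b_vocab_head(hidden, w1, w2).shape)  # torch.Size([1, 1, 122753])
```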
@@ -245,40 +231,15 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
    mlp_layer = curr_layer.mlp

    weights = []
    if n_splits_linear == 1:
        for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
                                    attn_layer.k_proj_dq_list,
                                    attn_layer.v_proj_dq_list,
                                    attn_layer.o_proj_dq_list,
                                    mlp_layer.gate_proj_dq_list,
                                    mlp_layer.up_proj_dq_list):
            weights.append((q.weight, q.scale))
            weights.append((k.weight, k.scale))
            weights.append((v.weight, v.scale))
            weights.append((o.weight, o.scale))
            weights.append((g.weight, g.scale))
            weights.append((u.weight, u.scale))
    else:
    for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                       attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
                       mlp_layer.down_proj_dq_list]:
        l_weights = []
        scales = []
        for l in layer_list:
            l_weights.append(l.weight)
            scales.append(l.scale)
            weights.append((torch.stack(l_weights, axis=0),
                            torch.stack(scales, axis=0)))

    if n_splits_down_proj == 1:
        for l in mlp_layer.down_proj_dq_list:
            weights.append((l.weight, l.scale))
    else:
        l_weights = []
        scales = []
        for l in mlp_layer.down_proj_dq_list:
            l_weights.append(l.weight)
            scales.append(l.scale)
        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

    cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
@@ -99,23 +99,10 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
    mlp_layer = curr_layer.mlp

    weights = []
    if n_splits_linear == 1:
        for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
                                    attn_layer.k_proj_dq_list,
                                    attn_layer.v_proj_dq_list,
                                    attn_layer.o_proj_dq_list,
                                    mlp_layer.gate_proj_dq_list,
                                    mlp_layer.up_proj_dq_list):
            weights.append((q.weight, q.scale))
            weights.append((k.weight, k.scale))
            weights.append((v.weight, v.scale))
            weights.append((o.weight, o.scale))
            weights.append((g.weight, g.scale))
            weights.append((u.weight, u.scale))
    else:
    for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                       attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
                       mlp_layer.down_proj_dq_list]:
        l_weights = []
        scales = []
        for l in layer_list:
@@ -123,17 +110,6 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
            scales.append(l.scale)
        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

    if n_splits_down_proj == 1:
        for l in mlp_layer.down_proj_dq_list:
            weights.append((l.weight, l.scale))
    else:
        l_weights = []
        scales = []
        for l in mlp_layer.down_proj_dq_list:
            l_weights.append(l.weight)
            scales.append(l.scale)
        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

    q_bias = attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16)
    k_bias = attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16)
    v_bias = attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16)