[NPU] Support llama groupwise (#12260)
* support llama gw * support llama gw lm_head * fix style * remove unused code
This commit is contained in:
		
							parent
							
								
									48fc63887d
								
							
						
					
					
						commit
						b5e663854b
					
				
					 5 changed files with 143 additions and 74 deletions
				
			
		| 
						 | 
				
			
			@ -414,7 +414,7 @@ class _BaseAutoModelClass:
 | 
			
		|||
            optimize_llm(model)
 | 
			
		||||
            with torch.no_grad():
 | 
			
		||||
                cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
 | 
			
		||||
                                 *model_args, **kwargs)
 | 
			
		||||
                                 quantization_group_size, *model_args, **kwargs)
 | 
			
		||||
                create_npu_kernels(model)
 | 
			
		||||
 | 
			
		||||
        if is_sharded:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -59,3 +59,23 @@ def split_linear(module, module_name, n_splits=2):
 | 
			
		|||
        new_linear.weight = torch.nn.Parameter(weight.contiguous(), requires_grad=False)
 | 
			
		||||
        linear_list.add_module(f"{module_name}_dq_{idx}", new_linear)
 | 
			
		||||
    return linear_list
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down_proj=2):
 | 
			
		||||
    from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention
 | 
			
		||||
    from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
 | 
			
		||||
    attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
 | 
			
		||||
    mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
 | 
			
		||||
    if isinstance(module, (Qwen2Attention, LlamaAttention)):
 | 
			
		||||
        for name in attn_module_names:
 | 
			
		||||
            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
 | 
			
		||||
                                                            n_splits=n_splits_hidden_size))
 | 
			
		||||
            delattr(module, name)
 | 
			
		||||
    elif isinstance(module, (Qwen2MLP, LlamaMLP)):
 | 
			
		||||
        for name in mlp_module_names:
 | 
			
		||||
            n_splits_mlp = n_splits_hidden_size
 | 
			
		||||
            if name == 'down_proj':
 | 
			
		||||
                n_splits_mlp = n_splits_down_proj
 | 
			
		||||
            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
 | 
			
		||||
                                                            n_splits=n_splits_mlp))
 | 
			
		||||
            delattr(module, name)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -87,8 +87,8 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
 | 
			
		|||
            model.llm.config.model_type = "llama"
 | 
			
		||||
        model = model.llm
 | 
			
		||||
 | 
			
		||||
    if model.config.model_type == "qwen2":
 | 
			
		||||
        from ipex_llm.transformers.npu_models.qwen2_mp import split_linears
 | 
			
		||||
    if model.config.model_type in ["qwen2", "llama"]:
 | 
			
		||||
        from ipex_llm.transformers.npu_models.common import split_linears
 | 
			
		||||
 | 
			
		||||
        if quantization_group_size == 0:
 | 
			
		||||
            n_splits_linear = 1
 | 
			
		||||
| 
						 | 
				
			
			@ -108,15 +108,19 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
 | 
			
		|||
        model.apply(lambda m: split_linears(m, n_splits_hidden_size=n_splits_linear,
 | 
			
		||||
                                            n_splits_down_proj=n_splits_down_proj))
 | 
			
		||||
 | 
			
		||||
        if quantization_group_size != 0:
 | 
			
		||||
            split_num = model.config.hidden_size // quantization_group_size
 | 
			
		||||
            new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
 | 
			
		||||
                                       bias=model.lm_head.bias, use_split=True)
 | 
			
		||||
            del model.lm_head
 | 
			
		||||
            model.lm_head = new_lm_head
 | 
			
		||||
 | 
			
		||||
    if model.config.model_type == "qwen2":
 | 
			
		||||
        # for Qwen2-7B-Insturct, divide lm_head into 14 parts
 | 
			
		||||
        if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
 | 
			
		||||
                not cpu_lm_head:
 | 
			
		||||
            # Do not split lm_head and use sym_int8 instead when mixed_precison is True
 | 
			
		||||
            if quantization_group_size != 0:
 | 
			
		||||
                split_num = model.config.hidden_size // quantization_group_size
 | 
			
		||||
                new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
 | 
			
		||||
                                           bias=model.lm_head.bias, use_split=True)
 | 
			
		||||
            else:
 | 
			
		||||
            if quantization_group_size == 0:
 | 
			
		||||
                # Do not split lm_head and use sym_int8 instead when mixed_precison is True
 | 
			
		||||
                is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
 | 
			
		||||
                split_num = 14 if is_split else 1
 | 
			
		||||
| 
						 | 
				
			
			@ -163,7 +167,7 @@ def optimize_llm(
 | 
			
		|||
        if intra_pp is None:
 | 
			
		||||
            intra_pp = 2
 | 
			
		||||
        if inter_pp is None:
 | 
			
		||||
            inter_pp = 2
 | 
			
		||||
            inter_pp = 2 if group_size == 0 else 8
 | 
			
		||||
 | 
			
		||||
        from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward
 | 
			
		||||
        from ipex_llm.transformers.npu_models.llama_mp import DecodeRunner, PrefillRunner
 | 
			
		||||
| 
						 | 
				
			
			@ -226,11 +230,6 @@ def optimize_llm(
 | 
			
		|||
        from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
 | 
			
		||||
        from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
 | 
			
		||||
        convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
 | 
			
		||||
 | 
			
		||||
        # for Qwen2-7B-Insturct, divide lm_head into 14 parts
 | 
			
		||||
        if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
 | 
			
		||||
                isinstance(model.lm_head, SlicedLMHead):
 | 
			
		||||
            model.lm_head.get_fused_lm_head()
 | 
			
		||||
    elif model.config.model_type == "minicpm":
 | 
			
		||||
        # for minicpm-1b
 | 
			
		||||
        if intra_pp is None:
 | 
			
		||||
| 
						 | 
				
			
			@ -299,3 +298,6 @@ def optimize_llm(
 | 
			
		|||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        convert_forward(model, module.BaichuanModel, baichuan_model_forward)
 | 
			
		||||
 | 
			
		||||
    if isinstance(model.lm_head, SlicedLMHead):
 | 
			
		||||
        model.lm_head.get_fused_lm_head()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -67,12 +67,18 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
 | 
			
		|||
        device: str = "NPU",
 | 
			
		||||
        rms_norm_eps,
 | 
			
		||||
        intermediate_size,
 | 
			
		||||
        n_splits_linear: int = 1,
 | 
			
		||||
        n_splits_down_proj: int = 1,
 | 
			
		||||
        group_size: int = 0
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__(max_seq_len=max_seq_len,
 | 
			
		||||
                         transpose_value=transpose_value,
 | 
			
		||||
                         dtype=dtype,
 | 
			
		||||
                         profile=profile,
 | 
			
		||||
                         device=device)
 | 
			
		||||
                         device=device,
 | 
			
		||||
                         n_splits_linear=n_splits_linear,
 | 
			
		||||
                         n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
                         group_size=group_size)
 | 
			
		||||
        self.max_seq_len = max_seq_len
 | 
			
		||||
        self.intermediate_size = intermediate_size
 | 
			
		||||
        self.dtype = dtype
 | 
			
		||||
| 
						 | 
				
			
			@ -215,7 +221,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
 | 
			
		|||
        hidden_states = self.eltwise_add(residual, attn_output)
 | 
			
		||||
        residual = hidden_states
 | 
			
		||||
        hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
 | 
			
		||||
        hidden_states = self.mlp(hidden_states)
 | 
			
		||||
        hidden_states = self.mlp(hidden_states, self.seq_len, self.mode)
 | 
			
		||||
        hidden_states = self.eltwise_add(residual, hidden_states)
 | 
			
		||||
        hidden_states = self.convert_to_fp16(hidden_states)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -241,6 +247,9 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
        max_seq_len: int = 1024,
 | 
			
		||||
        transpose_value: bool = False,
 | 
			
		||||
        do_print: bool = False,
 | 
			
		||||
        n_splits_linear: int = 1,
 | 
			
		||||
        n_splits_down_proj: int = 1,
 | 
			
		||||
        group_size: int = 0
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -250,6 +259,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
        for w in parameters:
 | 
			
		||||
            if isinstance(w, tuple):  # from QuantizedLinear
 | 
			
		||||
                op_parameters.append((w[0].numpy(), w[1].numpy()))
 | 
			
		||||
            elif w.dtype in [torch.int8, torch.uint8]:    # QuantizedLinear weight
 | 
			
		||||
                op_parameters.append(w.numpy())
 | 
			
		||||
            elif isinstance(w, np.ndarray):     # scale
 | 
			
		||||
                op_parameters.append(w)
 | 
			
		||||
            else:
 | 
			
		||||
                op_parameters.append(w.to(torch.float16).numpy())
 | 
			
		||||
        self.op_parameters = op_parameters
 | 
			
		||||
| 
						 | 
				
			
			@ -258,6 +271,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
        self.transpose_value = transpose_value
 | 
			
		||||
        if isinstance(parameters[0], tuple):
 | 
			
		||||
            np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
 | 
			
		||||
        elif parameters[0].dtype == torch.int8:
 | 
			
		||||
            np_dtype = np.int8
 | 
			
		||||
        elif parameters[0].dtype == torch.uint8:
 | 
			
		||||
            np_dtype = np.uint8
 | 
			
		||||
        else:  # FP16 Linear
 | 
			
		||||
            np_dtype = np.float16
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -292,6 +309,9 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
                mode="decode",
 | 
			
		||||
                transpose_value=self.transpose_value,
 | 
			
		||||
                dtype=np_dtype,
 | 
			
		||||
                n_splits_linear=n_splits_linear,
 | 
			
		||||
                n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
                group_size=group_size
 | 
			
		||||
            )
 | 
			
		||||
            self.backend_decoders.append(decoder)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -367,6 +387,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
        intermediate_size,
 | 
			
		||||
        max_seq_len: int = 128,
 | 
			
		||||
        transpose_value: bool = False,
 | 
			
		||||
        n_splits_linear: int = 1,
 | 
			
		||||
        n_splits_down_proj: int = 1,
 | 
			
		||||
        group_size: int = 0
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.op_parameters = parameters
 | 
			
		||||
| 
						 | 
				
			
			@ -395,6 +418,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
            mode="prefill",
 | 
			
		||||
            transpose_value=self.transpose_value,
 | 
			
		||||
            dtype=np_dtype,
 | 
			
		||||
            n_splits_linear=n_splits_linear,
 | 
			
		||||
            n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
            group_size=group_size
 | 
			
		||||
        )
 | 
			
		||||
        self.layer_norm_0 = layer_norm_0
 | 
			
		||||
        self.layer_norm_1 = layer_norm_1
 | 
			
		||||
| 
						 | 
				
			
			@ -474,24 +500,53 @@ def run_decode(
 | 
			
		|||
    head_dim = model.model.layers[layer_start].self_attn.head_dim
 | 
			
		||||
    rms_norm_eps = model.config.rms_norm_eps
 | 
			
		||||
    intermediate_size = model.config.intermediate_size
 | 
			
		||||
    group_size = getattr(model.config, "group_size", 0)
 | 
			
		||||
    layer_weights = []
 | 
			
		||||
    input_layer_norm_weights = []
 | 
			
		||||
    post_attn_layernorm_weights = []
 | 
			
		||||
    layer_indexs = range(layer_start, layer_end)
 | 
			
		||||
    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
 | 
			
		||||
    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
 | 
			
		||||
    for layer_idx in layer_indexs:
 | 
			
		||||
        curr_layer = model.model.layers[layer_idx]
 | 
			
		||||
        attn_layer = curr_layer.self_attn
 | 
			
		||||
        mlp_layer = curr_layer.mlp
 | 
			
		||||
 | 
			
		||||
        weights = [
 | 
			
		||||
            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
 | 
			
		||||
            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
 | 
			
		||||
            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
 | 
			
		||||
            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
 | 
			
		||||
            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
 | 
			
		||||
            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
 | 
			
		||||
            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
 | 
			
		||||
        ]
 | 
			
		||||
        weights = []
 | 
			
		||||
        if n_splits_linear == 1:
 | 
			
		||||
            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
 | 
			
		||||
                                        attn_layer.k_proj_dq_list,
 | 
			
		||||
                                        attn_layer.v_proj_dq_list,
 | 
			
		||||
                                        attn_layer.o_proj_dq_list,
 | 
			
		||||
                                        mlp_layer.gate_proj_dq_list,
 | 
			
		||||
                                        mlp_layer.up_proj_dq_list):
 | 
			
		||||
                weights.append((q.weight, q.scale))
 | 
			
		||||
                weights.append((k.weight, k.scale))
 | 
			
		||||
                weights.append((v.weight, v.scale))
 | 
			
		||||
                weights.append((o.weight, o.scale))
 | 
			
		||||
                weights.append((g.weight, g.scale))
 | 
			
		||||
                weights.append((u.weight, u.scale))
 | 
			
		||||
        else:
 | 
			
		||||
            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
 | 
			
		||||
                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
 | 
			
		||||
                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
 | 
			
		||||
                l_weights = []
 | 
			
		||||
                scales = []
 | 
			
		||||
                for l in layer_list:
 | 
			
		||||
                    l_weights.append(l.weight)
 | 
			
		||||
                    scales.append(l.scale)
 | 
			
		||||
                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
        if n_splits_down_proj == 1:
 | 
			
		||||
            for l in mlp_layer.down_proj_dq_list:
 | 
			
		||||
                weights.append((l.weight, l.scale))
 | 
			
		||||
        else:
 | 
			
		||||
            l_weights = []
 | 
			
		||||
            scales = []
 | 
			
		||||
            for l in mlp_layer.down_proj_dq_list:
 | 
			
		||||
                l_weights.append(l.weight)
 | 
			
		||||
                scales.append(l.scale)
 | 
			
		||||
            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
 | 
			
		||||
        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
 | 
			
		||||
| 
						 | 
				
			
			@ -518,6 +573,9 @@ def run_decode(
 | 
			
		|||
        max_seq_len=max_seq_len,
 | 
			
		||||
        transpose_value=transpose_value_cache,
 | 
			
		||||
        do_print=False,
 | 
			
		||||
        n_splits_linear=n_splits_linear,
 | 
			
		||||
        n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
        group_size=group_size
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    dist.barrier()
 | 
			
		||||
| 
						 | 
				
			
			@ -591,11 +649,15 @@ class DecodeRunner:
 | 
			
		|||
 | 
			
		||||
        self.forward_signal = torch.tensor(0, dtype=torch.int)
 | 
			
		||||
 | 
			
		||||
        n_layers_per_rank = num_layers // (world_size - 1)
 | 
			
		||||
        if num_layers % (world_size - 1) > 0:
 | 
			
		||||
            n_layers_per_rank += 1
 | 
			
		||||
 | 
			
		||||
        for rank in range(1, world_size):
 | 
			
		||||
            input_q = mp.Queue()
 | 
			
		||||
            output_q = mp.Queue()
 | 
			
		||||
            start_layer = (rank - 1) * (num_layers // (world_size - 1))
 | 
			
		||||
            end_layer = (rank) * (num_layers // (world_size - 1))
 | 
			
		||||
            start_layer = (rank - 1) * n_layers_per_rank
 | 
			
		||||
            end_layer = (rank) * n_layers_per_rank
 | 
			
		||||
            if rank == world_size - 1:
 | 
			
		||||
                end_layer = num_layers
 | 
			
		||||
            p = mp.Process(
 | 
			
		||||
| 
						 | 
				
			
			@ -676,25 +738,34 @@ def run_prefill(
 | 
			
		|||
    head_dim = model.model.layers[layer_start].self_attn.head_dim
 | 
			
		||||
    rms_norm_eps = model.config.rms_norm_eps
 | 
			
		||||
    intermediate_size = model.config.intermediate_size
 | 
			
		||||
    group_size = getattr(model.config, "group_size", 0)
 | 
			
		||||
    deocderlayers = []
 | 
			
		||||
    layer_weights = []
 | 
			
		||||
    input_layer_norm_weights = []
 | 
			
		||||
    post_attn_layernorm_weights = []
 | 
			
		||||
    layer_indexs = range(layer_start, layer_end)
 | 
			
		||||
    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
 | 
			
		||||
    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
 | 
			
		||||
    for layer_idx in layer_indexs:
 | 
			
		||||
        curr_layer = model.model.layers[layer_idx]
 | 
			
		||||
        attn_layer = curr_layer.self_attn
 | 
			
		||||
        mlp_layer = curr_layer.mlp
 | 
			
		||||
 | 
			
		||||
        weights = [
 | 
			
		||||
            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
 | 
			
		||||
            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
 | 
			
		||||
            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
 | 
			
		||||
            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
 | 
			
		||||
            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
 | 
			
		||||
            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
 | 
			
		||||
            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
 | 
			
		||||
        ]
 | 
			
		||||
        weights = []
 | 
			
		||||
 | 
			
		||||
        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
 | 
			
		||||
                           attn_layer.v_proj_dq_list):
 | 
			
		||||
            weights.append((q.weight, q.scale))
 | 
			
		||||
            weights.append((k.weight, k.scale))
 | 
			
		||||
            weights.append((v.weight, v.scale))
 | 
			
		||||
 | 
			
		||||
        for l in attn_layer.o_proj_dq_list:
 | 
			
		||||
            weights.append((l.weight, l.scale))
 | 
			
		||||
        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
 | 
			
		||||
            weights.append((g.weight, g.scale))
 | 
			
		||||
            weights.append((u.weight, u.scale))
 | 
			
		||||
        for l in mlp_layer.down_proj_dq_list:
 | 
			
		||||
            weights.append((l.weight, l.scale))
 | 
			
		||||
 | 
			
		||||
        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
 | 
			
		||||
        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
 | 
			
		||||
| 
						 | 
				
			
			@ -715,6 +786,9 @@ def run_prefill(
 | 
			
		|||
            intermediate_size=intermediate_size,
 | 
			
		||||
            max_seq_len=max_output_len,
 | 
			
		||||
            transpose_value=transpose_value_cache,
 | 
			
		||||
            n_splits_linear=n_splits_linear,
 | 
			
		||||
            n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
            group_size=group_size
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        layer_weights.extend(weights)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -42,27 +42,8 @@ from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
 | 
			
		|||
from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
 | 
			
		||||
from transformers.modeling_outputs import CausalLMOutputWithPast
 | 
			
		||||
from torch.nn import CrossEntropyLoss
 | 
			
		||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention
 | 
			
		||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
 | 
			
		||||
from ipex_llm.utils.common.log4Error import invalidInputError
 | 
			
		||||
from ipex_llm.transformers.npu_models.common import split_linear
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down_proj=2):
 | 
			
		||||
    attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
 | 
			
		||||
    mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
 | 
			
		||||
    if isinstance(module, Qwen2Attention):
 | 
			
		||||
        for name in attn_module_names:
 | 
			
		||||
            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
 | 
			
		||||
                                                            n_splits=n_splits_hidden_size))
 | 
			
		||||
            delattr(module, name)
 | 
			
		||||
    elif isinstance(module, Qwen2MLP):
 | 
			
		||||
        for name in mlp_module_names:
 | 
			
		||||
            n_splits_mlp = n_splits_hidden_size
 | 
			
		||||
            if name == 'down_proj':
 | 
			
		||||
                n_splits_mlp = n_splits_down_proj
 | 
			
		||||
            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
 | 
			
		||||
                                                            n_splits=n_splits_mlp))
 | 
			
		||||
            delattr(module, name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def split_mlp_down_proj(module: torch.nn.Module):
 | 
			
		||||
| 
						 | 
				
			
			@ -594,30 +575,22 @@ def run_decode(
 | 
			
		|||
 | 
			
		||||
        weights = []
 | 
			
		||||
        if n_splits_linear == 1:
 | 
			
		||||
            for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
 | 
			
		||||
                               attn_layer.v_proj_dq_list):
 | 
			
		||||
            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
 | 
			
		||||
                                        attn_layer.k_proj_dq_list,
 | 
			
		||||
                                        attn_layer.v_proj_dq_list,
 | 
			
		||||
                                        attn_layer.o_proj_dq_list,
 | 
			
		||||
                                        mlp_layer.gate_proj_dq_list,
 | 
			
		||||
                                        mlp_layer.up_proj_dq_list):
 | 
			
		||||
                weights.append((q.weight, q.scale))
 | 
			
		||||
                weights.append((k.weight, k.scale))
 | 
			
		||||
                weights.append((v.weight, v.scale))
 | 
			
		||||
 | 
			
		||||
            for l in attn_layer.o_proj_dq_list:
 | 
			
		||||
                weights.append((l.weight, l.scale))
 | 
			
		||||
        else:
 | 
			
		||||
            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
 | 
			
		||||
                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list]:
 | 
			
		||||
                l_weights = []
 | 
			
		||||
                scales = []
 | 
			
		||||
                for l in layer_list:
 | 
			
		||||
                    l_weights.append(l.weight)
 | 
			
		||||
                    scales.append(l.scale)
 | 
			
		||||
                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
        if n_splits_linear == 1:
 | 
			
		||||
            for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
 | 
			
		||||
                weights.append((o.weight, o.scale))
 | 
			
		||||
                weights.append((g.weight, g.scale))
 | 
			
		||||
                weights.append((u.weight, u.scale))
 | 
			
		||||
        else:
 | 
			
		||||
            for layer_list in [mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
 | 
			
		||||
            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
 | 
			
		||||
                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
 | 
			
		||||
                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
 | 
			
		||||
                l_weights = []
 | 
			
		||||
                scales = []
 | 
			
		||||
                for l in layer_list:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue