[NPU] Support l0 Llama groupwise (#12276)
* except lm_head * remove * support gw lm_head * update * fix * remove run.bat * fix style * support llama3
This commit is contained in:
		
							parent
							
								
									1cef0c4948
								
							
						
					
					
						commit
						4467645088
					
				
					 5 changed files with 85 additions and 24 deletions
				
			
		| 
						 | 
				
			
			@ -52,6 +52,7 @@ if __name__ == "__main__":
 | 
			
		|||
                        help='Prompt to infer')
 | 
			
		||||
    parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
 | 
			
		||||
    parser.add_argument("--max-context-len", type=int, default=1024)
 | 
			
		||||
    parser.add_argument("--quantization_group_size", type=int, default=0)
 | 
			
		||||
    parser.add_argument("--max-prompt-len", type=int, default=960)
 | 
			
		||||
    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -63,6 +64,7 @@ if __name__ == "__main__":
 | 
			
		|||
                                                 pipeline=True,
 | 
			
		||||
                                                 max_context_len=args.max_context_len,
 | 
			
		||||
                                                 max_prompt_len=args.max_prompt_len,
 | 
			
		||||
                                                 quantization_group_size=args.quantization_group_size,
 | 
			
		||||
                                                 torch_dtype=torch.float16,
 | 
			
		||||
                                                 attn_implementation="eager",
 | 
			
		||||
                                                 transpose_value_cache=not args.disable_transpose_value_cache)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -59,6 +59,7 @@ if __name__ == "__main__":
 | 
			
		|||
    parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
 | 
			
		||||
    parser.add_argument("--max-context-len", type=int, default=1024)
 | 
			
		||||
    parser.add_argument("--max-prompt-len", type=int, default=960)
 | 
			
		||||
    parser.add_argument("--quantization_group_size", type=int, default=0)
 | 
			
		||||
    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
| 
						 | 
				
			
			@ -70,6 +71,7 @@ if __name__ == "__main__":
 | 
			
		|||
                                                 pipeline=True,
 | 
			
		||||
                                                 max_context_len=args.max_context_len,
 | 
			
		||||
                                                 max_prompt_len=args.max_prompt_len,
 | 
			
		||||
                                                 quantization_group_size=args.quantization_group_size,
 | 
			
		||||
                                                 attn_implementation="eager",
 | 
			
		||||
                                                 transpose_value_cache=not args.disable_transpose_value_cache)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -186,7 +186,7 @@ class _BaseAutoModelClass:
 | 
			
		|||
                "max_prompt_len": max_prompt_len,
 | 
			
		||||
                "inter_pp": inter_pp,
 | 
			
		||||
                "intra_pp": intra_pp,
 | 
			
		||||
                "transpose_value_cache": transpose_value_cache,
 | 
			
		||||
                "transpose_value_cache": transpose_value_cache
 | 
			
		||||
            }
 | 
			
		||||
            model = cls.optimize_npu_model(*args, **optimize_kwargs)
 | 
			
		||||
        else:
 | 
			
		||||
| 
						 | 
				
			
			@ -260,7 +260,8 @@ class _BaseAutoModelClass:
 | 
			
		|||
            convert_llm(llm,
 | 
			
		||||
                        kv_len=max_context_len,
 | 
			
		||||
                        max_prompt_len=max_prompt_len,
 | 
			
		||||
                        transpose_value_cache=transpose_value_cache)
 | 
			
		||||
                        transpose_value_cache=transpose_value_cache,
 | 
			
		||||
                        group_size=quantization_group_size)
 | 
			
		||||
 | 
			
		||||
        return model
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -30,6 +30,7 @@ import threading
 | 
			
		|||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
import tempfile
 | 
			
		||||
import numpy as np
 | 
			
		||||
from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate(
 | 
			
		||||
| 
						 | 
				
			
			@ -225,7 +226,14 @@ def update_names_of_IR_and_export_blob(model, model_name, dir):
 | 
			
		|||
def convert_llm(model: torch.nn.Module,
 | 
			
		||||
                kv_len: int,
 | 
			
		||||
                max_prompt_len: int,
 | 
			
		||||
                transpose_value_cache: bool):
 | 
			
		||||
                transpose_value_cache: bool,
 | 
			
		||||
                group_size: int):
 | 
			
		||||
    if group_size == 0:
 | 
			
		||||
        n_splits_linear = 1
 | 
			
		||||
        n_splits_down_proj = 1
 | 
			
		||||
    else:
 | 
			
		||||
        n_splits_linear = model.config.hidden_size // group_size
 | 
			
		||||
        n_splits_down_proj = model.config.intermediate_size // group_size
 | 
			
		||||
    if model.config.model_type == "llama":
 | 
			
		||||
        from ipex_llm.transformers.npu_models.convert_mp import convert_llama
 | 
			
		||||
        convert_llama(model,
 | 
			
		||||
| 
						 | 
				
			
			@ -247,7 +255,17 @@ def convert_llm(model: torch.nn.Module,
 | 
			
		|||
            vocab_size = model.config.vocab_size
 | 
			
		||||
            model_norm = model.model.norm
 | 
			
		||||
            lm_head = model.lm_head
 | 
			
		||||
            weights = [(lm_head.weight, lm_head.scale)]
 | 
			
		||||
            if n_splits_linear == 1:
 | 
			
		||||
                weights = [(lm_head.weight, lm_head.scale)]
 | 
			
		||||
            else:
 | 
			
		||||
                lm_heads = lm_head.lm_heads
 | 
			
		||||
                lm_head_weights = []
 | 
			
		||||
                scales = []
 | 
			
		||||
                for i in range(n_splits_linear):
 | 
			
		||||
                    lm_head_weights.append(lm_heads[i].weight)
 | 
			
		||||
                    scales.append(lm_heads[i].scale)
 | 
			
		||||
                weights = [(torch.stack(lm_head_weights, axis=0),
 | 
			
		||||
                           torch.stack(scales, axis=0))]
 | 
			
		||||
            if isinstance(weights[0], tuple):
 | 
			
		||||
                np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
 | 
			
		||||
            else:  # FP16 Linear
 | 
			
		||||
| 
						 | 
				
			
			@ -264,13 +282,17 @@ def convert_llm(model: torch.nn.Module,
 | 
			
		|||
                dtype=np_dtype,
 | 
			
		||||
                model_norm_weight=model_norm.weight.to(torch.float16),
 | 
			
		||||
                vocab_size=vocab_size,
 | 
			
		||||
                n_splits=n_splits_linear
 | 
			
		||||
            )
 | 
			
		||||
            last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
 | 
			
		||||
 | 
			
		||||
            # save weights bins files
 | 
			
		||||
            weight_numpy = [
 | 
			
		||||
                lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
 | 
			
		||||
            ]
 | 
			
		||||
            if n_splits_linear == 1:
 | 
			
		||||
                weight_numpy = [
 | 
			
		||||
                    lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
 | 
			
		||||
                ]
 | 
			
		||||
            else:
 | 
			
		||||
                weight_numpy = [v.numpy() for v in weights[0]]
 | 
			
		||||
 | 
			
		||||
            for idx, weight in enumerate(weight_numpy):
 | 
			
		||||
                bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
 | 
			
		||||
| 
						 | 
				
			
			@ -295,20 +317,41 @@ def convert_llm(model: torch.nn.Module,
 | 
			
		|||
                mlp_layer = curr_layer.mlp
 | 
			
		||||
 | 
			
		||||
                weights = []
 | 
			
		||||
                for q, k, v, o, g, u, d in zip(attn_layer.q_proj_dq_list,
 | 
			
		||||
                                               attn_layer.k_proj_dq_list,
 | 
			
		||||
                                               attn_layer.v_proj_dq_list,
 | 
			
		||||
                                               attn_layer.o_proj_dq_list,
 | 
			
		||||
                                               mlp_layer.gate_proj_dq_list,
 | 
			
		||||
                                               mlp_layer.up_proj_dq_list,
 | 
			
		||||
                                               mlp_layer.down_proj_dq_list):
 | 
			
		||||
                    weights.append((q.weight, q.scale))
 | 
			
		||||
                    weights.append((k.weight, k.scale))
 | 
			
		||||
                    weights.append((v.weight, v.scale))
 | 
			
		||||
                    weights.append((o.weight, o.scale))
 | 
			
		||||
                    weights.append((g.weight, g.scale))
 | 
			
		||||
                    weights.append((u.weight, u.scale))
 | 
			
		||||
                    weights.append((d.weight, d.scale))
 | 
			
		||||
                if n_splits_linear == 1:
 | 
			
		||||
                    for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
 | 
			
		||||
                                                attn_layer.k_proj_dq_list,
 | 
			
		||||
                                                attn_layer.v_proj_dq_list,
 | 
			
		||||
                                                attn_layer.o_proj_dq_list,
 | 
			
		||||
                                                mlp_layer.gate_proj_dq_list,
 | 
			
		||||
                                                mlp_layer.up_proj_dq_list):
 | 
			
		||||
                        weights.append((q.weight, q.scale))
 | 
			
		||||
                        weights.append((k.weight, k.scale))
 | 
			
		||||
                        weights.append((v.weight, v.scale))
 | 
			
		||||
                        weights.append((o.weight, o.scale))
 | 
			
		||||
                        weights.append((g.weight, g.scale))
 | 
			
		||||
                        weights.append((u.weight, u.scale))
 | 
			
		||||
                else:
 | 
			
		||||
                    for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
 | 
			
		||||
                                       attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
 | 
			
		||||
                                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
 | 
			
		||||
                        l_weights = []
 | 
			
		||||
                        scales = []
 | 
			
		||||
                        for l in layer_list:
 | 
			
		||||
                            l_weights.append(l.weight)
 | 
			
		||||
                            scales.append(l.scale)
 | 
			
		||||
                        weights.append((torch.stack(l_weights, axis=0),
 | 
			
		||||
                                        torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
                if n_splits_down_proj == 1:
 | 
			
		||||
                    for l in mlp_layer.down_proj_dq_list:
 | 
			
		||||
                        weights.append((l.weight, l.scale))
 | 
			
		||||
                else:
 | 
			
		||||
                    l_weights = []
 | 
			
		||||
                    scales = []
 | 
			
		||||
                    for l in mlp_layer.down_proj_dq_list:
 | 
			
		||||
                        l_weights.append(l.weight)
 | 
			
		||||
                        scales.append(l.scale)
 | 
			
		||||
                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 | 
			
		||||
 | 
			
		||||
                cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
 | 
			
		||||
                cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
 | 
			
		||||
| 
						 | 
				
			
			@ -336,6 +379,9 @@ def convert_llm(model: torch.nn.Module,
 | 
			
		|||
                        mode="decode",
 | 
			
		||||
                        transpose_value=transpose_value_cache,
 | 
			
		||||
                        dtype=np_dtype,
 | 
			
		||||
                        n_splits_linear=n_splits_linear,
 | 
			
		||||
                        n_splits_down_proj=n_splits_down_proj,
 | 
			
		||||
                        group_size=group_size
 | 
			
		||||
                    )
 | 
			
		||||
                    rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
 | 
			
		||||
                                                                        "decoder_layer",
 | 
			
		||||
| 
						 | 
				
			
			@ -370,6 +416,9 @@ def convert_llm(model: torch.nn.Module,
 | 
			
		|||
        invalidInputError(False,
 | 
			
		||||
                          "Now we only support Llama2 for pipeline running.")
 | 
			
		||||
 | 
			
		||||
    if isinstance(model.lm_head, SlicedLMHead):
 | 
			
		||||
        model.lm_head.get_fused_lm_head()
 | 
			
		||||
 | 
			
		||||
    # patch generate function
 | 
			
		||||
    import types
 | 
			
		||||
    model.generate = types.MethodType(generate, model)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -36,6 +36,7 @@ class LowBitLlamaLMHead(LLMBaseNNFactory):
 | 
			
		|||
        transpose_value: bool = False,
 | 
			
		||||
        profile: bool = False,
 | 
			
		||||
        device: str = "NPU",
 | 
			
		||||
        n_splits: int = 1,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__(max_seq_len=max_seq_len,
 | 
			
		||||
                         transpose_value=transpose_value,
 | 
			
		||||
| 
						 | 
				
			
			@ -64,9 +65,15 @@ class LowBitLlamaLMHead(LLMBaseNNFactory):
 | 
			
		|||
        # model norm and lm head
 | 
			
		||||
        model_norm_weight = self.constant(model_norm_weight)
 | 
			
		||||
        hidden_states = self.layer_norm(hidden_states, model_norm_weight)
 | 
			
		||||
        hidden_states = self.linear(
 | 
			
		||||
            hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
 | 
			
		||||
        )
 | 
			
		||||
        if n_splits == 1:
 | 
			
		||||
            hidden_states = self.linear(
 | 
			
		||||
                hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            hidden_states = self.dq_split_linear(
 | 
			
		||||
                hidden_states, self.vocab_size, self.hidden_size, n_splits,
 | 
			
		||||
                wt_dtype=dtype, scale_factor=False
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # define outputs
 | 
			
		||||
        hidden_states = self.convert_to_fp32(hidden_states)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue