From 446764508882b115d6a78d4abacffaa9ab7e4ad0 Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Mon, 28 Oct 2024 11:06:55 +0200
Subject: [PATCH] [NPU] Support l0 Llama groupwise (#12276)

* except lm_head

* remove

* support gw lm_head

* update

* fix

* remove run.bat

* fix style

* support llama3
---
 .../LLM/Pipeline-Models/llama2.py             |  2 +
 .../LLM/Pipeline-Models/llama3.py             |  2 +
 .../src/ipex_llm/transformers/npu_model.py    |  5 +-
 .../npu_pipeline_model/convert_pipeline.py    | 87 +++++++++++++++----
 .../transformers/npu_pipeline_model/llama.py  | 13 ++-
 5 files changed, 85 insertions(+), 24 deletions(-)

diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama2.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama2.py
index 35d7826a..2d43c8ca 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama2.py
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama2.py
@@ -52,6 +52,7 @@ if __name__ == "__main__":
                         help='Prompt to infer')
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-context-len", type=int, default=1024)
+    parser.add_argument("--quantization_group_size", type=int, default=0)
     parser.add_argument("--max-prompt-len", type=int, default=960)
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 
@@ -63,6 +64,7 @@ if __name__ == "__main__":
         pipeline=True,
         max_context_len=args.max_context_len,
         max_prompt_len=args.max_prompt_len,
+        quantization_group_size=args.quantization_group_size,
         torch_dtype=torch.float16,
         attn_implementation="eager",
         transpose_value_cache=not args.disable_transpose_value_cache)
diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama3.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama3.py
index a3a8bf41..377cc17c 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama3.py
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama3.py
@@ -59,6 +59,7 @@ if __name__ == "__main__":
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-context-len", type=int, default=1024)
     parser.add_argument("--max-prompt-len", type=int, default=960)
+    parser.add_argument("--quantization_group_size", type=int, default=0)
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
 
     args = parser.parse_args()
@@ -70,6 +71,7 @@ if __name__ == "__main__":
         pipeline=True,
         max_context_len=args.max_context_len,
         max_prompt_len=args.max_prompt_len,
+        quantization_group_size=args.quantization_group_size,
         attn_implementation="eager",
         transpose_value_cache=not args.disable_transpose_value_cache)
 
diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index 1a855bb0..56a53a7e 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -186,7 +186,7 @@ class _BaseAutoModelClass:
                 "max_prompt_len": max_prompt_len,
                 "inter_pp": inter_pp,
                 "intra_pp": intra_pp,
-                "transpose_value_cache": transpose_value_cache,
+                "transpose_value_cache": transpose_value_cache
             }
             model = cls.optimize_npu_model(*args, **optimize_kwargs)
         else:
@@ -260,7 +260,8 @@ class _BaseAutoModelClass:
             convert_llm(llm,
                         kv_len=max_context_len,
                         max_prompt_len=max_prompt_len,
-                        transpose_value_cache=transpose_value_cache)
+                        transpose_value_cache=transpose_value_cache,
+                        group_size=quantization_group_size)
 
         return model
 
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
index 34407b93..31ff054c 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
@@ -30,6 +30,7 @@ import threading
 from ipex_llm.utils.common import invalidInputError
 import tempfile
 import numpy as np
+from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead
 
 
 def generate(
@@ -225,7 +226,14 @@ def update_names_of_IR_and_export_blob(model, model_name, dir):
 def convert_llm(model: torch.nn.Module,
                 kv_len: int,
                 max_prompt_len: int,
-                transpose_value_cache: bool):
+                transpose_value_cache: bool,
+                group_size: int):
+    if group_size == 0:
+        n_splits_linear = 1
+        n_splits_down_proj = 1
+    else:
+        n_splits_linear = model.config.hidden_size // group_size
+        n_splits_down_proj = model.config.intermediate_size // group_size
     if model.config.model_type == "llama":
         from ipex_llm.transformers.npu_models.convert_mp import convert_llama
         convert_llama(model,
@@ -247,7 +255,17 @@ def convert_llm(model: torch.nn.Module,
             vocab_size = model.config.vocab_size
             model_norm = model.model.norm
             lm_head = model.lm_head
-            weights = [(lm_head.weight, lm_head.scale)]
+            if n_splits_linear == 1:
+                weights = [(lm_head.weight, lm_head.scale)]
+            else:
+                lm_heads = lm_head.lm_heads
+                lm_head_weights = []
+                scales = []
+                for i in range(n_splits_linear):
+                    lm_head_weights.append(lm_heads[i].weight)
+                    scales.append(lm_heads[i].scale)
+                weights = [(torch.stack(lm_head_weights, axis=0),
+                            torch.stack(scales, axis=0))]
             if isinstance(weights[0], tuple):
                 np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
             else:  # FP16 Linear
@@ -264,13 +282,17 @@ def convert_llm(model: torch.nn.Module,
                 dtype=np_dtype,
                 model_norm_weight=model_norm.weight.to(torch.float16),
                 vocab_size=vocab_size,
+                n_splits=n_splits_linear
             )
             last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
 
             # save weights bins files
-            weight_numpy = [
-                lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
-            ]
+            if n_splits_linear == 1:
+                weight_numpy = [
+                    lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
+                ]
+            else:
+                weight_numpy = [v.numpy() for v in weights[0]]
 
             for idx, weight in enumerate(weight_numpy):
                 bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
@@ -295,20 +317,41 @@ def convert_llm(model: torch.nn.Module,
                 mlp_layer = curr_layer.mlp
 
                 weights = []
-                for q, k, v, o, g, u, d in zip(attn_layer.q_proj_dq_list,
-                                               attn_layer.k_proj_dq_list,
-                                               attn_layer.v_proj_dq_list,
-                                               attn_layer.o_proj_dq_list,
-                                               mlp_layer.gate_proj_dq_list,
-                                               mlp_layer.up_proj_dq_list,
-                                               mlp_layer.down_proj_dq_list):
-                    weights.append((q.weight, q.scale))
-                    weights.append((k.weight, k.scale))
-                    weights.append((v.weight, v.scale))
-                    weights.append((o.weight, o.scale))
-                    weights.append((g.weight, g.scale))
-                    weights.append((u.weight, u.scale))
-                    weights.append((d.weight, d.scale))
+                if n_splits_linear == 1:
+                    for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                                attn_layer.k_proj_dq_list,
+                                                attn_layer.v_proj_dq_list,
+                                                attn_layer.o_proj_dq_list,
+                                                mlp_layer.gate_proj_dq_list,
+                                                mlp_layer.up_proj_dq_list):
+                        weights.append((q.weight, q.scale))
+                        weights.append((k.weight, k.scale))
+                        weights.append((v.weight, v.scale))
+                        weights.append((o.weight, o.scale))
+                        weights.append((g.weight, g.scale))
+                        weights.append((u.weight, u.scale))
+                else:
+                    for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                                       attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                        l_weights = []
+                        scales = []
+                        for l in layer_list:
+                            l_weights.append(l.weight)
+                            scales.append(l.scale)
+                        weights.append((torch.stack(l_weights, axis=0),
+                                        torch.stack(scales, axis=0)))
+
+                if n_splits_down_proj == 1:
+                    for l in mlp_layer.down_proj_dq_list:
+                        weights.append((l.weight, l.scale))
+                else:
+                    l_weights = []
+                    scales = []
+                    for l in mlp_layer.down_proj_dq_list:
+                        l_weights.append(l.weight)
+                        scales.append(l.scale)
+                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
                 cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
                 cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -336,6 +379,9 @@ def convert_llm(model: torch.nn.Module,
                     mode="decode",
                     transpose_value=transpose_value_cache,
                     dtype=np_dtype,
+                    n_splits_linear=n_splits_linear,
+                    n_splits_down_proj=n_splits_down_proj,
+                    group_size=group_size
                 )
                 rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
                                                                     "decoder_layer",
@@ -370,6 +416,9 @@ def convert_llm(model: torch.nn.Module,
         invalidInputError(False,
                           "Now we only support Llama2 for pipeline running.")
 
+    if isinstance(model.lm_head, SlicedLMHead):
+        model.lm_head.get_fused_lm_head()
+
     # patch generate function
     import types
     model.generate = types.MethodType(generate, model)
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
index 9ad6acc1..ba88ffef 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
@@ -36,6 +36,7 @@ class LowBitLlamaLMHead(LLMBaseNNFactory):
         transpose_value: bool = False,
         profile: bool = False,
         device: str = "NPU",
+        n_splits: int = 1,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -64,9 +65,15 @@ class LowBitLlamaLMHead(LLMBaseNNFactory):
         # model norm and lm head
         model_norm_weight = self.constant(model_norm_weight)
         hidden_states = self.layer_norm(hidden_states, model_norm_weight)
-        hidden_states = self.linear(
-            hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
-        )
+        if n_splits == 1:
+            hidden_states = self.linear(
+                hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
+            )
+        else:
+            hidden_states = self.dq_split_linear(
+                hidden_states, self.vocab_size, self.hidden_size, n_splits,
+                wt_dtype=dtype, scale_factor=False
+            )
 
         # define outputs
         hidden_states = self.convert_to_fp32(hidden_states)
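Note (not part of the patch): the short sketch below only mirrors the group_size-to-split mapping that convert_llm now performs, to make the arithmetic concrete. The helper name compute_n_splits is illustrative, and the example numbers assume a Llama-2-7B-style config (hidden_size=4096, intermediate_size=11008).

# Illustrative sketch, not patch code: how group_size maps to the number of
# slices each quantized linear (and the sliced lm_head) is stacked into.
def compute_n_splits(hidden_size, intermediate_size, group_size):
    if group_size == 0:
        # group_size == 0 keeps the previous behaviour: a single, un-split linear
        return 1, 1
    n_splits_linear = hidden_size // group_size        # q/k/v/o, gate, up, lm_head
    n_splits_down_proj = intermediate_size // group_size  # down_proj has a wider input
    return n_splits_linear, n_splits_down_proj

# e.g. passing --quantization_group_size 128 to the updated llama2.py example
# with a Llama-2-7B-style config:
print(compute_n_splits(4096, 11008, 128))  # -> (32, 86)

With the default --quantization_group_size 0, both split counts stay at 1 and the original single-linear code path is used.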