diff --git a/python/llm/dev/benchmark/all-in-one/run.py b/python/llm/dev/benchmark/all-in-one/run.py
index fdf982cb..bbe10c22 100644
--- a/python/llm/dev/benchmark/all-in-one/run.py
+++ b/python/llm/dev/benchmark/all-in-one/run.py
@@ -641,7 +641,7 @@ def transformers_int4_npu_win(repo_id,
     model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
                                                  torch_dtype=torch.float16, optimize_model=optimize_model, max_context_len=max_context_len, max_prompt_len=int(in_out_len[0]),
                                                  quantization_group_size=npu_group_size, transpose_value_cache=transpose_value_cache,
-                                                 mixed_precision=True, save_directory=save_directory, use_cache=True, attn_implementation="eager").eval()
+                                                 save_directory=save_directory, use_cache=True, attn_implementation="eager").eval()
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     end = time.perf_counter()
     load_time = end - st
@@ -701,7 +701,6 @@ def transformers_int4_npu_pipeline_win(repo_id,
     model_path = get_model_path(repo_id, local_model_hub)
     in_out_len = in_out_pairs[0].split("-")
     max_context_len = max(int(in_out_len[0]) + int(in_out_len[1]), 1024)
-    mixed_precision = True if npu_group_size == 0 else False
     save_directory = "./save_converted_model_dir"
     # Load model in 4 bit,
     # which convert the relevant layers in the model into INT4 format
@@ -710,7 +709,7 @@ def transformers_int4_npu_pipeline_win(repo_id,
     model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True, pipeline=True,
                                                  torch_dtype=torch.float16, optimize_model=optimize_model, max_context_len=max_context_len, max_prompt_len=int(in_out_len[0]),
                                                  quantization_group_size=npu_group_size, transpose_value_cache=transpose_value_cache,
-                                                 use_cache=True, attn_implementation="eager", mixed_precision=mixed_precision,
+                                                 use_cache=True, attn_implementation="eager",
                                                  save_directory=save_directory).eval()
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/convert.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/convert.py
index e0811a5c..e236433c 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/convert.py
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/convert.py
@@ -68,7 +68,6 @@ if __name__ == "__main__":
                                                  torch_dtype=torch.float16,
                                                  attn_implementation="eager",
                                                  transpose_value_cache=not args.disable_transpose_value_cache,
-                                                 mixed_precision=True,
                                                  trust_remote_code=True,
                                                  convert_model=True,
                                                  save_directory=save_dir)
diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py
index e1f4be49..ca0475c7 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py
@@ -67,7 +67,6 @@ if __name__ == "__main__":
                                                  torch_dtype=torch.float16,
                                                  attn_implementation="eager",
                                                  transpose_value_cache=not args.disable_transpose_value_cache,
-                                                 mixed_precision=True,
                                                  trust_remote_code=True,
                                                  save_directory=args.save_directory)
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/qwen.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/qwen.py
index caf6d1b3..0b4c3b69 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/qwen.py
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/qwen.py
@@ -67,7 +67,6 @@ if __name__ == "__main__":
        max_context_len=args.max_context_len,
        max_prompt_len=args.max_prompt_len,
        transpose_value_cache=not args.disable_transpose_value_cache,
-       mixed_precision=True,
        quantization_group_size=args.quantization_group_size,
        save_directory=args.save_directory
     )
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
index 2e98c1eb..39c9cd00 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -153,16 +153,19 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
         if model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40:
             # workaround for MiniCPM-2B
             new_lm_head_0 = SlicedLMHead(model.lm_head_0.weight, split_num=split_num,
-                                         bias=model.lm_head_0.bias, use_split=True)
+                                         bias=model.lm_head_0.bias, use_split=True,
+                                         group_size=quantization_group_size)
             del model.lm_head_0
             model.lm_head_0 = new_lm_head_0
             new_lm_head_1 = SlicedLMHead(model.lm_head_1.weight, split_num=split_num,
-                                         bias=model.lm_head_1.bias, use_split=True)
+                                         bias=model.lm_head_1.bias, use_split=True,
+                                         group_size=quantization_group_size)
             del model.lm_head_1
             model.lm_head_1 = new_lm_head_1
         else:
             new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
-                                       bias=model.lm_head.bias, use_split=True)
+                                       bias=model.lm_head.bias, use_split=True,
+                                       group_size=quantization_group_size)
             del model.lm_head
             model.lm_head = new_lm_head
 
@@ -176,7 +179,8 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
         is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
         split_num = 14 if is_split else 1
         new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
-                                   bias=model.lm_head.bias, use_split=False)
+                                   bias=model.lm_head.bias, use_split=True,
+                                   group_size=quantization_group_size)
         del model.lm_head
         model.lm_head = new_lm_head
 
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py b/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py
index d422fe6c..f306ae0e 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py
@@ -35,6 +35,7 @@ class LMHeadLinear(NNFactory):
         device: str = "NPU",
         dtype: np.dtype = np.int8,
         use_split: bool = False,
+        group_size: int = 0,
     ):
         """Initialize the LMHeadLinear class.
 
@@ -57,7 +58,7 @@ class LMHeadLinear(NNFactory):
         if use_split:
             input = self.parameter((1, self.batch, self.inC))
             res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
-                                       scale_factor=False)
+                                       scale_factor=(group_size == 0))
         else:
             input = self.parameter((self.batch, self.inC))
             split_size = self.inC // split_num // 2 * 2
@@ -108,12 +109,13 @@ class LMHeadLinear(NNFactory):
 
 
 class SlicedLMHead(nn.Module):
-    def __init__(self, weight, bias, split_num, use_split=False):
+    def __init__(self, weight, bias, split_num, use_split=False, group_size=0):
         super().__init__()
         self.split_num = split_num
         self.outC, self.inC = weight.shape
         split_size = weight.size(1) // split_num // 2 * 2
         self.lm_heads = nn.Sequential()
+        self.group_size = group_size
         for i in range(split_num):
             new_linear = torch.nn.Linear(0, 0, bias=False)
             start_idx = i * split_size
@@ -159,7 +161,8 @@ class SlicedLMHead(nn.Module):
     def get_fused_lm_head(self):
         np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
         self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
-                                          False, "NPU", dtype=np_dtype, use_split=self.use_split)
+                                          False, "NPU", dtype=np_dtype, use_split=self.use_split,
+                                          group_size=self.group_size)
         if self.use_split:
             weights = []
             scales = []
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py
index 6abe95bc..b3829947 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py
@@ -85,6 +85,7 @@ class LowBitLLMLMHead(LLMBaseNNFactory):
         profile: bool = False,
         device: str = "NPU",
         n_splits: int = 1,
+        group_size: int = 0,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -117,7 +118,7 @@ class LowBitLLMLMHead(LLMBaseNNFactory):
         hidden_states = self.linear(
             hidden_states, self.vocab_size, self.hidden_size, bias=False,
             wt_dtype=self.dtype, n_splits=n_splits,
-            scale_factor=(n_splits == 1),
+            scale_factor=(group_size == 0),
         )
 
         # define outputs
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
index c299adff..2e6b249c 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
@@ -355,9 +355,10 @@ def convert_llm(model: torch.nn.Module,
             os.mkdir(weight_dir)
         layer_num = len(model.model.layers)
         from .qwen import convert_qwen_layer, convert_lm_head_and_embedding
-        first_blob_path, last_blob_path = convert_lm_head_and_embedding(model, n_splits_linear,
-                                                                        temp_dir, weight_dir,
-                                                                        convert_model)
+        first_blob_path, last_blob_path = convert_lm_head_and_embedding(model, temp_dir,
+                                                                        weight_dir,
+                                                                        convert_model,
+                                                                        group_size=group_size)
 
         param_list = []
         for layer_idx in range(0, layer_num):
@@ -470,9 +471,8 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                                save_directory, weight_dir, transpose_value_cache, max_prompt_len,
                                group_size, layernorm_const, "prefill")
         # save blob of lmhead and bin of embedding
-        convert_lm_head_and_embedding(model, n_splits_linear,
-                                      save_directory, weight_dir,
-                                      convert_model=True)
+        convert_lm_head_and_embedding(model, save_directory, weight_dir,
+                                      convert_model=True, group_size=group_size)
     elif model.config.model_type == "llama":
         embedding_post = False
         cos_sin_input = False
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
index 5760e7fb..e4b31824 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
@@ -22,14 +22,15 @@ from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLML
 from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead
 
 
-def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
-                                  convert_model=False):
+def convert_lm_head_and_embedding(model, temp_dir, weight_dir,
+                                  convert_model=False, group_size=0):
     num_heads = model.model.layers[0].self_attn.num_heads
     head_dim = model.model.layers[0].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     vocab_size = model.config.vocab_size
     model_norm = model.model.norm
     lm_head = model.lm_head
+    lm_head_n_splits = 1
     if not isinstance(lm_head, SlicedLMHead):
         weights = [(lm_head.weight, lm_head.scale)]
     else:
@@ -41,6 +42,7 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
             scales.append(l.scale)
         weights = [(torch.stack(lm_head_weights, axis=0),
                     torch.stack(scales, axis=0))]
+        lm_head_n_splits = lm_head.split_num
     if isinstance(weights[0], tuple):
         np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
     else:  # FP16 Linear
@@ -56,7 +58,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
         dtype=np_dtype,
         model_norm_weight=model_norm.weight.to(torch.float16),
         vocab_size=vocab_size,
-        n_splits=n_splits_linear
+        n_splits=lm_head_n_splits,
+        group_size=group_size,
     )
     last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, f"lm_head",
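
For context, here is a minimal caller-side sketch of what this patch changes for the examples and the benchmark script: `mixed_precision` is no longer passed to `from_pretrained`, and `quantization_group_size` is now threaded through to the sliced LM head, which internally picks `scale_factor=(group_size == 0)`. The repo id, low-bit format, and length settings below are illustrative placeholders, not values taken from this diff.

```python
# Sketch only; mirrors the call pattern of the updated NPU examples after this change.
# Placeholder values: model_path, load_in_low_bit, max_context_len, max_prompt_len.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model_path = "Qwen/Qwen2.5-7B-Instruct"   # placeholder repo id or local path
npu_group_size = 0                        # 0 -> channel-wise scales; >0 -> grouped quantization

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_low_bit="sym_int4",              # placeholder low-bit format
    optimize_model=True,
    torch_dtype=torch.float16,
    max_context_len=1024,
    max_prompt_len=512,
    quantization_group_size=npu_group_size,  # now also drives lm-head scale handling
    transpose_value_cache=True,
    save_directory="./save_converted_model_dir",
    use_cache=True,
    attn_implementation="eager",
    trust_remote_code=True,
)  # note: no mixed_precision=True argument anymore
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
```

Reading the diff, the intent of the `scale_factor=(group_size == 0)` changes appears to be that the fused scale path is used only for channel-wise quantization (`group_size == 0`), while a non-zero group size leaves per-group scales to be handled separately by the split linear.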