diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index c63eef64..9400f65e 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -414,7 +414,7 @@ class _BaseAutoModelClass:
         optimize_llm(model)
         with torch.no_grad():
             cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
-                             *model_args, **kwargs)
+                             quantization_group_size, *model_args, **kwargs)
             create_npu_kernels(model)
 
         if is_sharded:
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/common.py b/python/llm/src/ipex_llm/transformers/npu_models/common.py
index d4592eac..92d48b0c 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/common.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/common.py
@@ -59,3 +59,23 @@ def split_linear(module, module_name, n_splits=2):
         new_linear.weight = torch.nn.Parameter(weight.contiguous(), requires_grad=False)
         linear_list.add_module(f"{module_name}_dq_{idx}", new_linear)
     return linear_list
+
+
+def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down_proj=2):
+    from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention
+    from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
+    attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
+    mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
+    if isinstance(module, (Qwen2Attention, LlamaAttention)):
+        for name in attn_module_names:
+            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
+                                                            n_splits=n_splits_hidden_size))
+            delattr(module, name)
+    elif isinstance(module, (Qwen2MLP, LlamaMLP)):
+        for name in mlp_module_names:
+            n_splits_mlp = n_splits_hidden_size
+            if name == 'down_proj':
+                n_splits_mlp = n_splits_down_proj
+            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
+                                                            n_splits=n_splits_mlp))
+            delattr(module, name)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
index fb39f27a..24af4f1f 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -87,8 +87,8 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
         model.llm.config.model_type = "llama"
         model = model.llm
 
-    if model.config.model_type == "qwen2":
-        from ipex_llm.transformers.npu_models.qwen2_mp import split_linears
+    if model.config.model_type in ["qwen2", "llama"]:
+        from ipex_llm.transformers.npu_models.common import split_linears
 
         if quantization_group_size == 0:
             n_splits_linear = 1
@@ -108,15 +108,19 @@
         model.apply(lambda m: split_linears(m, n_splits_hidden_size=n_splits_linear,
                                             n_splits_down_proj=n_splits_down_proj))
 
+        if quantization_group_size != 0:
+            split_num = model.config.hidden_size // quantization_group_size
+            new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
+                                       bias=model.lm_head.bias, use_split=True)
+            del model.lm_head
+            model.lm_head = new_lm_head
+
+    if model.config.model_type == "qwen2":
         # for Qwen2-7B-Insturct, divide lm_head into 14 parts
         if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
                 not cpu_lm_head:
             # Do not split lm_head and use sym_int8 instead when mixed_precison is True
-            if quantization_group_size != 0:
-                split_num = model.config.hidden_size // quantization_group_size
-                new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
-                                           bias=model.lm_head.bias, use_split=True)
-            else:
+            if quantization_group_size == 0:
                 # Do not split lm_head and use sym_int8 instead when mixed_precison is True
                 is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
                 split_num = 14 if is_split else 1
@@ -163,7 +167,7 @@
         if intra_pp is None:
             intra_pp = 2
         if inter_pp is None:
-            inter_pp = 2
+            inter_pp = 2 if group_size == 0 else 8
 
         from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward
         from ipex_llm.transformers.npu_models.llama_mp import DecodeRunner, PrefillRunner
@@ -226,11 +230,6 @@
         from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
         from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
         convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
-
-        # for Qwen2-7B-Insturct, divide lm_head into 14 parts
-        if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
-                isinstance(model.lm_head, SlicedLMHead):
-            model.lm_head.get_fused_lm_head()
     elif model.config.model_type == "minicpm":
         # for minicpm-1b
         if intra_pp is None:
@@ -299,3 +298,6 @@
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         convert_forward(model, module.BaichuanModel, baichuan_model_forward)
+
+    if isinstance(model.lm_head, SlicedLMHead):
+        model.lm_head.get_fused_lm_head()
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index 6039d94d..d37d4623 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -67,12 +67,18 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         device: str = "NPU",
         rms_norm_eps,
         intermediate_size,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
                          dtype=dtype,
                          profile=profile,
-                         device=device)
+                         device=device,
+                         n_splits_linear=n_splits_linear,
+                         n_splits_down_proj=n_splits_down_proj,
+                         group_size=group_size)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -215,7 +221,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         hidden_states = self.eltwise_add(residual, attn_output)
         residual = hidden_states
         hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, self.seq_len, self.mode)
         hidden_states = self.eltwise_add(residual, hidden_states)
         hidden_states = self.convert_to_fp16(hidden_states)
 
@@ -241,6 +247,9 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         max_seq_len: int = 1024,
         transpose_value: bool = False,
         do_print: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__()
 
@@ -250,6 +259,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         for w in parameters:
             if isinstance(w, tuple):  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif w.dtype in [torch.int8, torch.uint8]:  # QuantizedLinear weight
+                op_parameters.append(w.numpy())
+            elif isinstance(w, np.ndarray):  # scale
+                op_parameters.append(w)
             else:
                 op_parameters.append(w.to(torch.float16).numpy())
         self.op_parameters = op_parameters
@@ -258,6 +271,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         self.transpose_value = transpose_value
         if isinstance(parameters[0], tuple):
             np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
+        elif parameters[0].dtype == torch.int8:
+            np_dtype = np.int8
+        elif parameters[0].dtype == torch.uint8:
+            np_dtype = np.uint8
         else:  # FP16 Linear
             np_dtype = np.float16
 
@@ -292,6 +309,9 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
                 mode="decode",
                 transpose_value=self.transpose_value,
                 dtype=np_dtype,
+                n_splits_linear=n_splits_linear,
+                n_splits_down_proj=n_splits_down_proj,
+                group_size=group_size
             )
             self.backend_decoders.append(decoder)
 
@@ -367,6 +387,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         intermediate_size,
         max_seq_len: int = 128,
         transpose_value: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -395,6 +418,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
             mode="prefill",
             transpose_value=self.transpose_value,
             dtype=np_dtype,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -474,24 +500,53 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
+    group_size = getattr(model.config, "group_size", 0)
 
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
    for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp
-        weights = [
-            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        weights = []
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -518,6 +573,9 @@ def run_decode(
         max_seq_len=max_seq_len,
         transpose_value=transpose_value_cache,
         do_print=False,
+        n_splits_linear=n_splits_linear,
+        n_splits_down_proj=n_splits_down_proj,
+        group_size=group_size
     )
 
     dist.barrier()
@@ -591,11 +649,15 @@ class DecodeRunner:
 
         self.forward_signal = torch.tensor(0, dtype=torch.int)
 
+        n_layers_per_rank = num_layers // (world_size - 1)
+        if num_layers % (world_size - 1) > 0:
+            n_layers_per_rank += 1
+
         for rank in range(1, world_size):
             input_q = mp.Queue()
             output_q = mp.Queue()
-            start_layer = (rank - 1) * (num_layers // (world_size - 1))
-            end_layer = (rank) * (num_layers // (world_size - 1))
+            start_layer = (rank - 1) * n_layers_per_rank
+            end_layer = (rank) * n_layers_per_rank
             if rank == world_size - 1:
                 end_layer = num_layers
             p = mp.Process(
@@ -676,25 +738,34 @@ def run_prefill(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
+    group_size = getattr(model.config, "group_size", 0)
 
     deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp
-        weights = [
-            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        weights = []
+
+        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                           attn_layer.v_proj_dq_list):
+            weights.append((q.weight, q.scale))
+            weights.append((k.weight, k.scale))
+            weights.append((v.weight, v.scale))
+
+        for l in attn_layer.o_proj_dq_list:
+            weights.append((l.weight, l.scale))
+        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
+            weights.append((g.weight, g.scale))
+            weights.append((u.weight, u.scale))
+        for l in mlp_layer.down_proj_dq_list:
+            weights.append((l.weight, l.scale))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -715,6 +786,9 @@
             intermediate_size=intermediate_size,
             max_seq_len=max_output_len,
             transpose_value=transpose_value_cache,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )
 
         layer_weights.extend(weights)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
index e4092576..f6952af2 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
@@ -42,27 +42,8 @@ from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
 from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from torch.nn import CrossEntropyLoss
-from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention
+from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
 from ipex_llm.utils.common.log4Error import invalidInputError
-from ipex_llm.transformers.npu_models.common import split_linear
-
-
-def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down_proj=2):
-    attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
-    mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
-    if isinstance(module, Qwen2Attention):
-        for name in attn_module_names:
-            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
-                                                            n_splits=n_splits_hidden_size))
-            delattr(module, name)
-    elif isinstance(module, Qwen2MLP):
-        for name in mlp_module_names:
-            n_splits_mlp = n_splits_hidden_size
-            if name == 'down_proj':
-                n_splits_mlp = n_splits_down_proj
-            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
-                                                            n_splits=n_splits_mlp))
-            delattr(module, name)
 
 
 def split_mlp_down_proj(module: torch.nn.Module):
@@ -594,30 +575,22 @@
         weights = []
         if n_splits_linear == 1:
-            for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                               attn_layer.v_proj_dq_list):
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
                 weights.append((q.weight, q.scale))
                 weights.append((k.weight, k.scale))
                 weights.append((v.weight, v.scale))
-
-            for l in attn_layer.o_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list]:
-                l_weights = []
-                scales = []
-                for l in layer_list:
-                    l_weights.append(l.weight)
-                    scales.append(l.scale)
-                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
-        if n_splits_linear == 1:
-            for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
+                weights.append((o.weight, o.scale))
                 weights.append((g.weight, g.scale))
                 weights.append((u.weight, u.scale))
         else:
-            for layer_list in [mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
                 l_weights = []
                 scales = []
                 for l in layer_list:
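For readers unfamiliar with the splitting scheme this patch generalizes from Qwen2 to Llama: split_linears walks each attention/MLP module and replaces every projection (q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj) with a <name>_dq_list of smaller linears produced by split_linear, whose weights and scales are later stacked (the torch.stack branches above) and handed to the NPU decoder layers. The snippet below is a minimal, self-contained sketch of that splitting idea only; split_linear_demo is a hypothetical name invented for illustration, not the ipex-llm implementation, and details such as bias handling and quantized dtypes are omitted.

# Minimal sketch of the layer-splitting idea behind split_linear / split_linears.
# split_linear_demo is an illustrative helper, not part of ipex-llm.
import torch


def split_linear_demo(linear: torch.nn.Linear, n_splits: int = 2) -> torch.nn.ModuleList:
    out_features, in_features = linear.weight.shape
    assert in_features % n_splits == 0, "in_features must be divisible by n_splits"
    split_size = in_features // n_splits
    pieces = torch.nn.ModuleList()
    # Chunk the weight along the input dimension; each chunk becomes its own Linear,
    # analogous to the `{name}_dq_{idx}` sub-modules created in common.py.
    for w in torch.chunk(linear.weight, n_splits, dim=1):
        piece = torch.nn.Linear(split_size, out_features, bias=False)
        piece.weight = torch.nn.Parameter(w.contiguous(), requires_grad=False)
        pieces.append(piece)
    return pieces


if __name__ == "__main__":
    full = torch.nn.Linear(8, 4, bias=False)
    parts = split_linear_demo(full, n_splits=2)
    x = torch.randn(3, 8)
    # Each sub-linear consumes its slice of the input features; the partial outputs
    # sum to the original result, so splitting does not change the computation.
    partial = sum(p(x_slice) for p, x_slice in zip(parts, torch.chunk(x, 2, dim=1)))
    assert torch.allclose(full(x), partial, atol=1e-5)

In the diff above the number of splits is tied to the quantization group size (n_splits_linear stays 1 when quantization_group_size is 0, and the lm_head split count is hidden_size // quantization_group_size), so each split lines up with the per-group quantization scales.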