diff --git a/python/llm/dev/benchmark/all-in-one/save_npu.py b/python/llm/dev/benchmark/all-in-one/save_npu.py index 3270ee99..7bffa7a4 100644 --- a/python/llm/dev/benchmark/all-in-one/save_npu.py +++ b/python/llm/dev/benchmark/all-in-one/save_npu.py @@ -30,7 +30,9 @@ current_dir = os.path.dirname(os.path.realpath(__file__)) def save_npu_model_in_low_bit(repo_id, local_model_hub, low_bit, - max_output_len, max_prompt_len, intra_pp, inter_pp, disable_transpose_value_cache): + max_output_len, max_prompt_len, intra_pp, inter_pp, + disable_transpose_value_cache, + quantization_group_size): model_path = get_model_path(repo_id, local_model_hub) # Load model in 4 bit, # which convert the relevant layers in the model into INT4 format @@ -47,6 +49,7 @@ def save_npu_model_in_low_bit(repo_id, intra_pp=intra_pp, inter_pp=inter_pp, transpose_value_cache=not disable_transpose_value_cache, + quantization_group_size=quantization_group_size ) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) end = time.perf_counter() @@ -54,6 +57,7 @@ def save_npu_model_in_low_bit(repo_id, model.save_low_bit(model_path+'-npu-'+low_bit) tokenizer.save_pretrained(model_path+'-npu-'+low_bit) + print(f"Model saved to {model_path+'-npu-'+low_bit}") if __name__ == "__main__": @@ -65,6 +69,7 @@ if __name__ == "__main__": parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) parser.add_argument("--intra-pp", type=int, default=2) parser.add_argument("--inter-pp", type=int, default=2) + parser.add_argument("--quantization_group_size", type=int, default=0) args = parser.parse_args() from omegaconf import OmegaConf @@ -78,5 +83,6 @@ if __name__ == "__main__": max_prompt_len=args.max_prompt_len, intra_pp=args.intra_pp, inter_pp=args.inter_pp, - disable_transpose_value_cache=args.disable_transpose_value_cache + disable_transpose_value_cache=args.disable_transpose_value_cache, + quantization_group_size=args.quantization_group_size, ) diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py index 56ca664c..c9936f25 100644 --- a/python/llm/src/ipex_llm/transformers/npu_model.py +++ b/python/llm/src/ipex_llm/transformers/npu_model.py @@ -81,6 +81,8 @@ class _BaseAutoModelClass: :param mixed_precision: boolean value, Whether to use mixed precision quantization. Default to be False. If set to ``True``, we will use ``'sym_int8'`` for lm_head when ``load_in_low_bit`` is '``sym_int4``' for certain models. 
+ :param quantization_group_size: int, quantization group size, The recommended + quantization_group_size are 0, 32, 64 or 128 :return: a model instance """ if kwargs.get("device_map", None) not in [None, "cpu", "auto"]: @@ -126,6 +128,15 @@ class _BaseAutoModelClass: transpose_value_cache = kwargs.pop("transpose_value_cache", True) modules_to_not_convert = kwargs.pop("modules_to_not_convert", []) mixed_precision = kwargs.pop('mixed_precision', False) + quantization_group_size = kwargs.pop("quantization_group_size", 0) + + invalidInputError( + quantization_group_size in [0, 32, 64, 128], + ( + "The recommended quantization_group_size are 0, 32, 64 or 128," + f"but got {quantization_group_size}" + ) + ) _args = copy.deepcopy(args) _kwargs = copy.deepcopy(kwargs) @@ -162,8 +173,11 @@ class _BaseAutoModelClass: with torch.no_grad(): model.config.update({"mixed_precision": mixed_precision}) - optimize_llm_pre(model, qtype, mixed_precision) - cls.load_convert(qtype, model, "cpu", modules_to_not_convert, *args, **kwargs) + model.config.update({"group_size": quantization_group_size}) + optimize_llm_pre(model, qtype, mixed_precision, + quantization_group_size=quantization_group_size) + cls.load_convert(qtype, model, "cpu", modules_to_not_convert, + quantization_group_size, *args, **kwargs) create_npu_kernels(llm) model = model.eval() logger.info(f"Finish to convert model") @@ -177,6 +191,7 @@ class _BaseAutoModelClass: inter_pp=inter_pp, intra_pp=intra_pp, transpose_value_cache=transpose_value_cache, + group_size=quantization_group_size ) model.save_low_bit = types.MethodType(save_low_bit, model) else: @@ -197,11 +212,13 @@ class _BaseAutoModelClass: return model @classmethod - def load_convert(cls, q_k, optimize_model, device, modules_to_not_convert, *arg, **kwarg): + def load_convert(cls, q_k, optimize_model, device, modules_to_not_convert, + group_size=0, *arg, **kwarg): from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear replace_with_QuantizedLinear(optimize_model, q_k, device=device, - modules_to_not_convert=modules_to_not_convert) + modules_to_not_convert=modules_to_not_convert, + group_size=group_size) @classmethod @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import) @@ -214,6 +231,7 @@ class _BaseAutoModelClass: ignore_argument(kwargs, "speculative") ignore_argument(kwargs, "pipeline_parallel_stages") ignore_argument(kwargs, "mixed_precision") + ignore_argument(kwargs, "quantization_group_size") optimize_model = kwargs.pop("optimize_model", False) max_output_len = kwargs.pop("max_output_len", 1024) max_prompt_len = kwargs.pop("max_prompt_len", 512) @@ -264,6 +282,7 @@ class _BaseAutoModelClass: qtype = config_dict.pop("bigdl_transformers_low_bit", False) bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True) mixed_precision = config_dict.pop("mixed_precision", False) + quantization_group_size = config_dict.pop("group_size", 0) invalidInputError( qtype, @@ -376,9 +395,10 @@ class _BaseAutoModelClass: llm = model with torch.no_grad(): - optimize_llm_pre(model, qtype, mixed_precision) + optimize_llm_pre(model, qtype, mixed_precision, + quantization_group_size=quantization_group_size) cls.load_convert(qtype, model, quant_device, modules_to_not_convert, - *model_args, **kwargs) + quantization_group_size, *model_args, **kwargs) create_npu_kernels(llm) else: @@ -458,6 +478,7 @@ class _BaseAutoModelClass: inter_pp=inter_pp, intra_pp=intra_pp, transpose_value_cache=transpose_value_cache, + group_size=quantization_group_size ) 
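For context on how the new quantization_group_size argument is meant to be used end to end, here is a minimal usage sketch mirroring the updated save_npu.py benchmark script above; the checkpoint path and the sym_int4 low-bit type are placeholders, not part of the patch.

from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model_path = "Qwen/Qwen2-7B-Instruct"  # placeholder checkpoint

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    load_in_low_bit="sym_int4",
    optimize_model=True,
    max_output_len=1024,
    max_prompt_len=512,
    quantization_group_size=64,  # must be 0, 32, 64 or 128; 0 keeps the per-channel path
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# The chosen group size is recorded in model.config as "group_size", so a checkpoint
# saved with save_low_bit() can later be reloaded via load_low_bit() without
# passing quantization_group_size again.
model.save_low_bit(model_path + "-npu-sym_int4")
tokenizer.save_pretrained(model_path + "-npu-sym_int4")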
return model diff --git a/python/llm/src/ipex_llm/transformers/npu_models/common.py b/python/llm/src/ipex_llm/transformers/npu_models/common.py index 32841838..d4592eac 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/common.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/common.py @@ -16,6 +16,7 @@ import torch from typing import List +from ipex_llm.utils.common.log4Error import invalidInputError def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear: @@ -40,3 +41,21 @@ def reshape_lm_head_input(x): shape[1] = 1 x = x[:, -1, :].view(shape) return x + + +def split_linear(module, module_name, n_splits=2): + in_features = module.in_features + invalidInputError(in_features % n_splits == 0, + f"in_features of the linear layer {module_name} must be divisible by" + f" n_splits, but got in_features: {in_features}, n_splits: {n_splits}") + weight_split = torch.tensor_split(module.weight, n_splits, dim=1) + linear_list = torch.nn.ModuleList() + bias = module.bias + for idx, weight in enumerate(weight_split): + new_linear = torch.nn.Linear(weight.size(1), + weight.size(0), + bias=False if bias is None else True) + new_linear.bias = bias + new_linear.weight = torch.nn.Parameter(weight.contiguous(), requires_grad=False) + linear_list.add_module(f"{module_name}_dq_{idx}", new_linear) + return linear_list diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index d2df2977..a6b7a1cb 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -31,7 +31,8 @@ def module_optimization(func) -> torch.nn.Module: torch.nn.Module: optimized module """ - def wrapper(model: torch.nn.Module, qtype, device, modules_to_not_convert, *args, **kwargs): + def wrapper(model: torch.nn.Module, qtype, device, modules_to_not_convert, + group_size=0, *args, **kwargs): """Recursively apply the optimization function. 
Args: @@ -42,18 +43,22 @@ def module_optimization(func) -> torch.nn.Module: """ for name, layer in model.named_children(): if name not in modules_to_not_convert: - new_layer = func(layer, qtype, device, modules_to_not_convert, *args, **kwargs) + new_layer = func(layer, qtype, device, modules_to_not_convert, + group_size=group_size, *args, **kwargs) if new_layer: model.add_module(name, new_layer) - wrapper(new_layer, qtype, device, modules_to_not_convert, *args, **kwargs) + wrapper(new_layer, qtype, device, modules_to_not_convert, + group_size=group_size, *args, **kwargs) else: - wrapper(layer, qtype, device, modules_to_not_convert, *args, **kwargs) + wrapper(layer, qtype, device, modules_to_not_convert, + group_size=group_size, *args, **kwargs) return wrapper @module_optimization -def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert): +def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert, + group_size): from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype from ipex_llm.ggml.quantize import ggml_tensor_qtype iqtype = ggml_tensor_qtype[qtype] @@ -66,7 +71,8 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert): iqtype = ggml_tensor_qtype[qtype] qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32), iqtype, device=device) - return QuantizedLinear(qweights, scale, layer.bias) + return QuantizedLinear(qweights, scale, layer.bias, + group_size=group_size) def convert_forward(m, target_m, new_forward): diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py index 47c94782..fb39f27a 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py @@ -19,6 +19,7 @@ import importlib import numpy as np from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params from ipex_llm.transformers.npu_models.lm_head import LMHeadLinear, SlicedLMHead +from ipex_llm.utils.common.log4Error import invalidInputError def convert_forward(m, target_m, new_forward): @@ -29,7 +30,8 @@ def convert_forward(m, target_m, new_forward): convert_forward(sub_m, target_m, new_forward) -def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision): +def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision, + quantization_group_size=0): if model.config.model_type == "baichuan": # process NormHead module in Baichuan2 7B if hasattr(model, 'lm_head') and model.lm_head is not None: @@ -86,17 +88,40 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision): model = model.llm if model.config.model_type == "qwen2": - from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj - model.apply(split_mlp_down_proj) + from ipex_llm.transformers.npu_models.qwen2_mp import split_linears + + if quantization_group_size == 0: + n_splits_linear = 1 + n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1 + else: + invalidInputError( + model.config.hidden_size % quantization_group_size == 0 and + model.config.intermediate_size % quantization_group_size == 0, + "The model hidden_size and intermediate_size should be divisible by " + f"quantization_group_size, but got hidden_size: {model.config.hidden_size}, " + f"intermediate_size: {model.config.intermediate_size}, and " + f"quantization_group_size: {quantization_group_size}" + ) + n_splits_linear = model.config.hidden_size // quantization_group_size + 
n_splits_down_proj = model.config.intermediate_size // quantization_group_size + + model.apply(lambda m: split_linears(m, n_splits_hidden_size=n_splits_linear, + n_splits_down_proj=n_splits_down_proj)) # for Qwen2-7B-Insturct, divide lm_head into 14 parts if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \ not cpu_lm_head: # Do not split lm_head and use sym_int8 instead when mixed_precison is True - is_split = (not mixed_precision) and qtype == "sym_int4_rtn" - split_num = 14 if is_split else 1 - new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num, - bias=model.lm_head.bias) + if quantization_group_size != 0: + split_num = model.config.hidden_size // quantization_group_size + new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num, + bias=model.lm_head.bias, use_split=True) + else: + # Do not split lm_head and use sym_int8 instead when mixed_precison is True + is_split = (not mixed_precision) and qtype == "sym_int4_rtn" + split_num = 14 if is_split else 1 + new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num, + bias=model.lm_head.bias, use_split=False) del model.lm_head model.lm_head = new_lm_head @@ -132,6 +157,7 @@ def optimize_llm( inter_pp=None, intra_pp=None, transpose_value_cache=True, + group_size=0 ): if model.config.model_type == "llama": if intra_pp is None: @@ -168,7 +194,13 @@ def optimize_llm( if intra_pp is None: intra_pp = 2 if inter_pp is None: - inter_pp = 2 if model.config.intermediate_size == 18944 else 1 + if model.config.intermediate_size == 18944: + if group_size != 0: + inter_pp = 5 + else: + inter_pp = 2 + else: + inter_pp = 1 from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner diff --git a/python/llm/src/ipex_llm/transformers/npu_models/linear.py b/python/llm/src/ipex_llm/transformers/npu_models/linear.py index 804751d2..d419da30 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/linear.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/linear.py @@ -130,6 +130,7 @@ class QuantizedLinear(torch.nn.Module): weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None, + group_size: int = False, ): """Initialize the QuantizedLinear class. @@ -154,10 +155,13 @@ class QuantizedLinear(torch.nn.Module): ) ) self.outC, self.inC = self.weight.shape - if self.weight.dtype == torch.uint8: - # In case is Int4 we need to double the input channels because weights are compressed - self.inC *= 2 - self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False) + if group_size != 0: + self.scale = Parameter(scale, requires_grad=False) + else: + if self.weight.dtype == torch.uint8: + # Int4 we need to double the input channels because weights are compressed + self.inC *= 2 + self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False) self.bias = bias self.op_id = str(uuid.uuid4()) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py b/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py index 3dc05b6a..d422fe6c 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
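Before the lm_head.py changes below, it may help to make the split-count arithmetic from optimize_llm_pre above concrete. The following self-contained example uses the Qwen2-7B shapes that appear elsewhere in this patch (hidden_size 3584, intermediate_size 18944); it only re-derives the numbers computed in optimize_llm_pre and adds no new behaviour.

# Reproduces the n_splits derivation from optimize_llm_pre for a few group sizes.
hidden_size, intermediate_size = 3584, 18944  # Qwen2-7B

for quantization_group_size in (0, 64, 128):
    if quantization_group_size == 0:
        # Legacy per-channel path: only the 18944-wide down_proj is halved.
        n_splits_linear = 1
        n_splits_down_proj = 2 if intermediate_size == 18944 else 1
    else:
        assert hidden_size % quantization_group_size == 0
        assert intermediate_size % quantization_group_size == 0
        n_splits_linear = hidden_size // quantization_group_size           # 56, 28
        n_splits_down_proj = intermediate_size // quantization_group_size  # 296, 148
    print(quantization_group_size, n_splits_linear, n_splits_down_proj)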
- import torch from torch import nn import numpy as np +from filelock import FileLock from intel_npu_acceleration_library.backend import NNFactory from intel_npu_acceleration_library.backend.bindings import lib as backend_lib @@ -34,6 +34,7 @@ class LMHeadLinear(NNFactory): profile: bool = False, device: str = "NPU", dtype: np.dtype = np.int8, + use_split: bool = False, ): """Initialize the LMHeadLinear class. @@ -51,26 +52,40 @@ class LMHeadLinear(NNFactory): self.inC, self.outC = inC, outC self.batch = batch - input = self.parameter((self.batch, self.inC)) - self.split_num = split_num - split_size = self.inC // split_num // 2 * 2 - for i in range(self.split_num): - start_idx = i * split_size - end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC - input_slice = self.slice(input, begin=[0, start_idx], - end=[self.batch, end_idx]) - linear_slice = self.linear(input_slice, outC, split_size, bias=False, wt_dtype=dtype) - if i == 0: - res = linear_slice - else: - res += linear_slice + if use_split: + input = self.parameter((1, self.batch, self.inC)) + res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype, + scale_factor=False) + else: + input = self.parameter((self.batch, self.inC)) + split_size = self.inC // split_num // 2 * 2 + + for i in range(self.split_num): + start_idx = i * split_size + end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC + input_slice = self.slice(input, begin=[0, start_idx], + end=[self.batch, end_idx]) + linear_slice = self.linear(input_slice, outC, split_size, bias=False, + wt_dtype=dtype) + if i == 0: + res = linear_slice + else: + res += linear_slice print("start compiling lm_head") self.compile() print("end compiling lm_head") + def set_weights(self, op_id, weights): + self.set_weights_async(op_id, weights) + with FileLock(f"lmhead_run.lock"): + backend_lib.run(self._mm) + + def set_weights_async(self, op_id, weights): + self.setWeights(1, op_id, *weights) + def run( self, X: np.ndarray ) -> np.ndarray: @@ -93,7 +108,7 @@ class LMHeadLinear(NNFactory): class SlicedLMHead(nn.Module): - def __init__(self, weight, bias, split_num): + def __init__(self, weight, bias, split_num, use_split=False): super().__init__() self.split_num = split_num self.outC, self.inC = weight.shape @@ -110,6 +125,7 @@ class SlicedLMHead(nn.Module): new_linear.out_features = new_weight.size(0) self.lm_heads.append(new_linear) self.bias = bias + self.use_split = use_split def forward(self, hidden_states): if hidden_states.size(0) * hidden_states.size(1) == 1: @@ -143,9 +159,19 @@ class SlicedLMHead(nn.Module): def get_fused_lm_head(self): np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8 self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num, - False, "NPU", dtype=np_dtype) - fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(), - self.lm_heads[i].scale.data.numpy()) - for i in range(self.split_num)] - self.fused_lm_head.setWeights(1, self.lm_heads[0].op_id, - *fused_lm_head_weights) + False, "NPU", dtype=np_dtype, use_split=self.use_split) + if self.use_split: + weights = [] + scales = [] + for i in range(self.split_num): + weights.append(self.lm_heads[i].weight) + scales.append(self.lm_heads[i].scale) + fused_lm_head_weights = (torch.stack(weights, axis=0).numpy(), + torch.stack(scales, axis=0).numpy()) + else: + fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(), + self.lm_heads[i].scale.data.numpy()) + for i in range(self.split_num)] + + 
self.fused_lm_head.set_weights(self.lm_heads[0].op_id, + fused_lm_head_weights) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 6f959aea..0080b40a 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -27,6 +27,8 @@ from filelock import FileLock import ctypes import math import numpy as np +from typing import Optional, Any, List +import numpy.typing as npt logger = logging.get_logger(__name__) @@ -60,6 +62,12 @@ def run_model( op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy())) op_args_flatten.append(op_args[-1][0]) op_args_flatten.append(op_args[-1][1]) + elif w.dtype in [torch.int8, torch.uint8]: # QuantizedLinear weight + op_args.append(w.numpy()) + op_args_flatten.append(op_args[-1]) + elif isinstance(w, np.ndarray): # scale + op_args.append(w) + op_args_flatten.append(op_args[-1]) else: op_args.append(set_contiguous(w).to(torch.float16).numpy()) op_args_flatten.append(op_args[-1]) @@ -94,7 +102,8 @@ def run_model( class LLMBaseNNFactory(NNFactory): - def __init__(self, max_seq_len, transpose_value, dtype, profile=False, device="NPU"): + def __init__(self, max_seq_len, transpose_value, dtype, profile=False, device="NPU", + n_splits_linear=1, n_splits_down_proj=1, group_size=False): super().__init__(profile, device) self.cache_parameter_ops = [] self.input_ops = [] @@ -104,6 +113,9 @@ class LLMBaseNNFactory(NNFactory): self.max_seq_len = max_seq_len self.transpose_value = transpose_value self.dtype = dtype + self.n_splits_linear = n_splits_linear + self.n_splits_down_proj = n_splits_down_proj + self.group_size = group_size def attention(self, *, @@ -124,31 +136,92 @@ class LLMBaseNNFactory(NNFactory): v_bias=None): hidden_size = num_heads * head_dim num_key_value_groups = num_heads // num_key_value_heads - query_states = self.linear( - hidden_states, - num_heads * head_dim, - hidden_size, - bias=False, - wt_dtype=self.dtype, - ) + groupsize = hidden_size // self.n_splits_linear + if self.n_splits_linear == 1: + query_states = self.linear( + hidden_states, + num_heads * head_dim, + hidden_size, + bias=False, + wt_dtype=self.dtype, + ) + + key_states = self.linear( + hidden_states, + num_key_value_heads * head_dim, + hidden_size, + bias=False, + wt_dtype=self.dtype, + ) + + value_states = self.linear( + hidden_states, + num_key_value_heads * head_dim, + hidden_size, + bias=False, + wt_dtype=self.dtype, + ) + else: + hidden_states = self.unsqueeze(hidden_states, axis=0) + if mode == "prefill": + query_states_to_concat = [] + key_states_to_concat = [] + value_states_to_concat = [] + for i in range(self.n_splits_linear): + sub_hidden_states = self.slice(hidden_states, + begin=[0, 0, i * groupsize], + end=[1, seq_len, (i + 1) * groupsize]) + query_states_to_concat.append( + self.linear( + sub_hidden_states, + num_heads * head_dim, + groupsize, + bias=False, + wt_dtype=self.dtype, + scale_factor=(self.group_size == 0) + ) + ) + key_states_to_concat.append( + self.linear( + sub_hidden_states, + num_key_value_heads * head_dim, + groupsize, + bias=False, + wt_dtype=self.dtype, + scale_factor=(self.group_size == 0) + ) + ) + value_states_to_concat.append( + self.linear( + sub_hidden_states, + num_key_value_heads * head_dim, + groupsize, + bias=False, + wt_dtype=self.dtype, + scale_factor=(self.group_size == 0) + ) + ) + query_states = sum(query_states_to_concat) + 
key_states = sum(key_states_to_concat) + value_states = sum(value_states_to_concat) + else: + query_states = self.dq_split_linear(hidden_states, num_heads * head_dim, + hidden_size, self.n_splits_linear, + wt_dtype=self.dtype, + scale_factor=(self.group_size == 0)) + key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim, + hidden_size, self.n_splits_linear, + wt_dtype=self.dtype, + scale_factor=(self.group_size == 0)) + value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim, + hidden_size, self.n_splits_linear, + wt_dtype=self.dtype, + scale_factor=(self.group_size == 0)) + if q_bias is not None: query_states = query_states + q_bias - key_states = self.linear( - hidden_states, - num_key_value_heads * head_dim, - hidden_size, - bias=False, - wt_dtype=self.dtype, - ) if k_bias is not None: key_states = key_states + k_bias - value_states = self.linear( - hidden_states, - num_key_value_heads * head_dim, - hidden_size, - bias=False, - wt_dtype=self.dtype, - ) if v_bias is not None: value_states = value_states + v_bias @@ -215,23 +288,100 @@ class LLMBaseNNFactory(NNFactory): attn_output = self.transpose(attn_output, [0, 2, 1, 3]) attn_output = self.reshape(attn_output, [1, seq_len, hidden_size]) - attn_output = self.linear( - attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype - ) + if self.n_splits_linear == 1: + attn_output = self.linear( + attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype + ) + else: + if mode == "prefill": + attn_output_to_concat = [] + for i in range(self.n_splits_linear): + sub_attn_output = self.slice(attn_output, + begin=[0, 0, i * groupsize], + end=[1, seq_len, (i + 1) * groupsize]) + attn_output_to_concat.append( + self.linear( + sub_attn_output, hidden_size, groupsize, bias=False, + wt_dtype=self.dtype, scale_factor=(self.group_size == 0) + ) + ) + attn_output = sum(attn_output_to_concat) + else: + attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size, + self.n_splits_linear, wt_dtype=self.dtype, + scale_factor=(self.group_size == 0)) return attn_output, new_key_states, new_value_states - def mlp(self, hidden_states): - mm1 = self.linear( - hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype - ) - mm2 = self.linear( - hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype - ) # type: ignore[attr-defined] - mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] - hidden_states = self.linear( - mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype - ) + def mlp(self, hidden_states, seq_len=-1, mode="prefill"): + if self.n_splits_linear == 1: + mm1 = self.linear( + hidden_states, self.intermediate_size, self.hidden_size, bias=False, + wt_dtype=self.dtype + ) + mm2 = self.linear( + hidden_states, self.intermediate_size, self.hidden_size, bias=False, + wt_dtype=self.dtype + ) # type: ignore[attr-defined] + mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] + else: + invalidInputError(seq_len > 0, "seq_len should be provided if use split linear") + if mode == "prefill": + gate_up_groupsize = self.hidden_size // self.n_splits_linear + mm1_to_concat = [] + mm2_to_concat = [] + for i in range(self.n_splits_linear): + sub_hidden_states = self.slice(hidden_states, + begin=[0, 0, i * gate_up_groupsize], + end=[1, seq_len, (i + 1) * gate_up_groupsize]) + mm1_to_concat.append( + self.linear( + sub_hidden_states, self.intermediate_size, 
gate_up_groupsize, + bias=False, + wt_dtype=self.dtype, scale_factor=(self.group_size == 0) + ) + ) + mm2_to_concat.append( + self.linear( + sub_hidden_states, self.intermediate_size, gate_up_groupsize, + bias=False, + wt_dtype=self.dtype, scale_factor=(self.group_size == 0) + ) + ) + mm1 = sum(mm1_to_concat) + mm2 = sum(mm2_to_concat) + else: + mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size, + self.n_splits_linear, wt_dtype=self.dtype, + scale_factor=(self.group_size == 0)) + mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size, + self.n_splits_linear, wt_dtype=self.dtype, + scale_factor=(self.group_size == 0)) + mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] + + if self.n_splits_down_proj == 1: + hidden_states = self.linear( + mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype + ) + else: + invalidInputError(seq_len > 0, "seq_len should be provided if use split linear") + if mode == "prefill": + down_groupsize = self.intermediate_size // self.n_splits_down_proj + hidden_states_to_concat = [] + for i in range(self.n_splits_down_proj): + sub_mm1 = self.slice(mm1, begin=[0, 0, i * down_groupsize], + end=[1, seq_len, (i + 1) * down_groupsize]) + hidden_states_to_concat.append( + self.linear( + sub_mm1, self.hidden_size, down_groupsize, bias=False, + wt_dtype=self.dtype, scale_factor=(self.group_size == 0) + ) + ) + hidden_states = sum(hidden_states_to_concat) + else: + hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size, + self.n_splits_down_proj, wt_dtype=self.dtype, + scale_factor=(self.group_size == 0)) return hidden_states def layer_norm(self, hidden_states, layernorm_weight): @@ -341,6 +491,19 @@ class LLMBaseNNFactory(NNFactory): self.linear_ops.append(op) return op + def dq_split_linear(self, + input_node: ctypes._Pointer, + output_channels: int, + input_channels: int, + n_splits: int, + act_dtype: npt.DTypeLike = np.float16, + wt_dtype: npt.DTypeLike = np.float16, + scale_factor: bool = False): + op = super().dq_split_linear(input_node, n_splits, output_channels, input_channels, + False, act_dtype, wt_dtype, scale_factor) + self.linear_ops.append(op) + return op + def parameter(self, shape): invalidInputError(False, ("parameter should not be called directly, " diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py index ab4c1948..b4ad6770 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py @@ -42,7 +42,27 @@ from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory from ipex_llm.transformers.npu_models.common import reshape_lm_head_input from transformers.modeling_outputs import CausalLMOutputWithPast from torch.nn import CrossEntropyLoss -from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP +from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention +from ipex_llm.utils.common.log4Error import invalidInputError +from ipex_llm.transformers.npu_models.common import split_linear + + +def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down_proj=2): + attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"] + mlp_module_names = ["down_proj", "up_proj", "gate_proj"] + if isinstance(module, Qwen2Attention): + for name in attn_module_names: + setattr(module, f"{name}_dq_list", 
split_linear(getattr(module, name), name, + n_splits=n_splits_hidden_size)) + delattr(module, name) + elif isinstance(module, Qwen2MLP): + for name in mlp_module_names: + n_splits_mlp = n_splits_hidden_size + if name == 'down_proj': + n_splits_mlp = n_splits_down_proj + setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name, + n_splits=n_splits_mlp)) + delattr(module, name) def split_mlp_down_proj(module: torch.nn.Module): @@ -94,12 +114,18 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory): device: str = "NPU", rms_norm_eps, intermediate_size, + n_splits_linear: int = 1, + n_splits_down_proj: int = 1, + group_size: int = 0 ): super().__init__(max_seq_len=max_seq_len, transpose_value=transpose_value, dtype=dtype, profile=profile, - device=device) + device=device, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size) self.max_seq_len = max_seq_len self.intermediate_size = intermediate_size self.dtype = dtype @@ -221,32 +247,9 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory): new_key_states = self.convert_to_fp16(curr_key_values[i][0]) new_value_states = self.convert_to_fp16(curr_key_values[i][1]) - print("start compiling") + print(f"{mode} start compiling") self.compile() - print("end compiling") - - def mlp(self, hidden_states, seq_len): - mm1 = self.linear( - hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype - ) - mm2 = self.linear( - hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype - ) # type: ignore[attr-defined] - mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] - if self.intermediate_size == 18944: - # for qwen2-7b - mm1_0 = self.slice(mm1, begin=[0, 0, 0], end=[1, seq_len, 9472]) - mm1_1 = self.slice(mm1, begin=[0, 0, 9472], end=[1, seq_len, 18944]) - hidden_states_0 = self.linear(mm1_0, self.hidden_size, 9472, - bias=False, wt_dtype=self.dtype) - hidden_states_1 = self.linear(mm1_1, self.hidden_size, 9472, - bias=False, wt_dtype=self.dtype) - hidden_states = hidden_states_0 + hidden_states_1 - else: - hidden_states = self.linear( - mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype - ) - return hidden_states + print(f"{mode} end compiling") def build_decoder( self, @@ -285,7 +288,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory): hidden_states = self.eltwise_add(residual, attn_output) residual = hidden_states hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight) - hidden_states = self.mlp(hidden_states, self.seq_len) + hidden_states = self.mlp(hidden_states, self.seq_len, self.mode) hidden_states = self.eltwise_add(residual, hidden_states) hidden_states = self.convert_to_fp16(hidden_states) @@ -314,6 +317,9 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module): max_seq_len: int = 1024, transpose_value: bool = False, do_print: bool = False, + n_splits_linear: int = 1, + n_splits_down_proj: int = 1, + group_size: int = 0, ): super().__init__() @@ -323,6 +329,10 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module): for w in parameters: if isinstance(w, tuple): # from QuantizedLinear op_parameters.append((w[0].numpy(), w[1].numpy())) + elif w.dtype in [torch.int8, torch.uint8]: # QuantizedLinear weight + op_parameters.append(w.numpy()) + elif isinstance(w, np.ndarray): # scale + op_parameters.append(w) else: op_parameters.append(w.to(torch.float16).numpy()) self.op_parameters = op_parameters @@ -331,6 +341,10 @@ class 
FusedQwenLowBitMultiDecoderlayer(torch.nn.Module): self.transpose_value = transpose_value if isinstance(parameters[0], tuple): np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 + elif parameters[0].dtype == torch.int8: + np_dtype = np.int8 + elif parameters[0].dtype == torch.uint8: + np_dtype = np.uint8 else: # FP16 Linear np_dtype = np.float16 @@ -368,6 +382,9 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module): mode="decode", transpose_value=self.transpose_value, dtype=np_dtype, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) self.backend_decoders.append(decoder) @@ -450,6 +467,9 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module): intermediate_size, max_seq_len: int = 128, transpose_value: bool = False, + n_splits_linear: int = 1, + n_splits_down_proj: int = 1, + group_size: int = 0, ): super().__init__() self.op_parameters = parameters @@ -478,6 +498,9 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module): mode="prefill", transpose_value=self.transpose_value, dtype=np_dtype, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) self.layer_norm_0 = layer_norm_0 self.layer_norm_1 = layer_norm_1 @@ -554,6 +577,7 @@ def run_decode( head_dim = model.model.layers[layer_start].self_attn.head_dim rms_norm_eps = model.config.rms_norm_eps intermediate_size = model.config.intermediate_size + group_size = getattr(model.config, "group_size", 0) layer_weights = [] input_layer_norm_weights = [] post_attn_layernorm_weights = [] @@ -561,34 +585,56 @@ def run_decode( k_biases = [] v_biases = [] layer_indexs = range(layer_start, layer_end) + n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) + n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) for layer_idx in layer_indexs: curr_layer = model.model.layers[layer_idx] attn_layer = curr_layer.self_attn mlp_layer = curr_layer.mlp - if model.config.intermediate_size == 8960: - # for qwen2-1.5b - weights = [ - (attn_layer.q_proj.weight, attn_layer.q_proj.scale), - (attn_layer.k_proj.weight, attn_layer.k_proj.scale), - (attn_layer.v_proj.weight, attn_layer.v_proj.scale), - (attn_layer.o_proj.weight, attn_layer.o_proj.scale), - (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), - (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), - (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale), - ] - elif model.config.intermediate_size == 18944: - # for qwen2-7b - weights = [ - (attn_layer.q_proj.weight, attn_layer.q_proj.scale), - (attn_layer.k_proj.weight, attn_layer.k_proj.scale), - (attn_layer.v_proj.weight, attn_layer.v_proj.scale), - (attn_layer.o_proj.weight, attn_layer.o_proj.scale), - (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), - (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), - (mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale), - (mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale) - ] + weights = [] + if n_splits_linear == 1: + for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list): + weights.append((q.weight, q.scale)) + weights.append((k.weight, k.scale)) + weights.append((v.weight, v.scale)) + + for l in attn_layer.o_proj_dq_list: + weights.append((l.weight, l.scale)) + else: + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + 
scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + + if n_splits_linear == 1: + for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list): + weights.append((g.weight, g.scale)) + weights.append((u.weight, u.scale)) + else: + for layer_list in [mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]: + l_weights = [] + scales = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + + if n_splits_down_proj == 1: + for l in mlp_layer.down_proj_dq_list: + weights.append((l.weight, l.scale)) + else: + l_weights = [] + scales = [] + for l in mlp_layer.down_proj_dq_list: + l_weights.append(l.weight) + scales.append(l.scale) + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) @@ -598,9 +644,9 @@ def run_decode( layer_weights.extend(weights) input_layer_norm_weights.append(layer_norm_0) post_attn_layernorm_weights.append(layer_norm_1) - q_biases.append(attn_layer.q_proj.bias.to(torch.float16)) - k_biases.append(attn_layer.k_proj.bias.to(torch.float16)) - v_biases.append(attn_layer.v_proj.bias.to(torch.float16)) + q_biases.append(attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16)) + k_biases.append(attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16)) + v_biases.append(attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16)) multi_decoder = FusedQwenLowBitMultiDecoderlayer( parameters=layer_weights, @@ -621,6 +667,9 @@ def run_decode( max_seq_len=max_seq_len, transpose_value=transpose_value_cache, do_print=False, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) dist.barrier() @@ -703,11 +752,15 @@ class DecodeRunner: self.forward_signal = torch.tensor(0, dtype=torch.int) + n_layers_per_rank = num_layers // (world_size - 1) + if num_layers % (world_size - 1) > 0: + n_layers_per_rank += 1 + for rank in range(1, world_size): input_q = mp.Queue() output_q = mp.Queue() - start_layer = (rank - 1) * (num_layers // (world_size - 1)) - end_layer = (rank) * (num_layers // (world_size - 1)) + start_layer = (rank - 1) * n_layers_per_rank + end_layer = (rank) * n_layers_per_rank if rank == world_size - 1: end_layer = num_layers p = mp.Process( @@ -787,39 +840,34 @@ def run_prefill( head_dim = model.model.layers[layer_start].self_attn.head_dim rms_norm_eps = model.config.rms_norm_eps intermediate_size = model.config.intermediate_size + group_size = getattr(model.config, "group_size", 0) deocderlayers = [] layer_weights = [] input_layer_norm_weights = [] post_attn_layernorm_weights = [] layer_indexs = range(layer_start, layer_end) + n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) + n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) for layer_idx in layer_indexs: curr_layer = model.model.layers[layer_idx] attn_layer = curr_layer.self_attn mlp_layer = curr_layer.mlp - if model.config.intermediate_size == 8960: - # for qwen2-1.5b - weights = [ - (attn_layer.q_proj.weight, attn_layer.q_proj.scale), - (attn_layer.k_proj.weight, attn_layer.k_proj.scale), - (attn_layer.v_proj.weight, attn_layer.v_proj.scale), - (attn_layer.o_proj.weight, attn_layer.o_proj.scale), - (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), - (mlp_layer.up_proj.weight, 
mlp_layer.up_proj.scale), - (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale), - ] - elif model.config.intermediate_size == 18944: - # for qwen2-7b - weights = [ - (attn_layer.q_proj.weight, attn_layer.q_proj.scale), - (attn_layer.k_proj.weight, attn_layer.k_proj.scale), - (attn_layer.v_proj.weight, attn_layer.v_proj.scale), - (attn_layer.o_proj.weight, attn_layer.o_proj.scale), - (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), - (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), - (mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale), - (mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale) - ] + weights = [] + + for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list): + weights.append((q.weight, q.scale)) + weights.append((k.weight, k.scale)) + weights.append((v.weight, v.scale)) + + for l in attn_layer.o_proj_dq_list: + weights.append((l.weight, l.scale)) + for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list): + weights.append((g.weight, g.scale)) + weights.append((u.weight, u.scale)) + for l in mlp_layer.down_proj_dq_list: + weights.append((l.weight, l.scale)) cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) @@ -835,14 +883,17 @@ def run_prefill( cached_sin=cached_sin, layer_norm_0=layer_norm_0, layer_norm_1=layer_norm_1, - q_bias=attn_layer.q_proj.bias.to(torch.float16), - k_bias=attn_layer.k_proj.bias.to(torch.float16), - v_bias=attn_layer.v_proj.bias.to(torch.float16), + q_bias=attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16), + k_bias=attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16), + v_bias=attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16), layer_idx=layer_idx, rms_norm_eps=rms_norm_eps, intermediate_size=intermediate_size, max_seq_len=max_output_len, transpose_value=transpose_value_cache, + n_splits_linear=n_splits_linear, + n_splits_down_proj=n_splits_down_proj, + group_size=group_size ) layer_weights.extend(weights)
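The weight packing above (and the per-slice linears used in the prefill graphs) relies on one invariant: splitting a projection column-wise into a *_dq_list and summing the sub-projections over matching input slices reproduces the original linear exactly. A self-contained sketch of that invariant with toy shapes follows; bias handling is simplified, since in the patch the bias is added once in the NPU graph (taken from the *_dq_0 sub-linear) rather than per slice.

import torch

def split_linear_ref(module: torch.nn.Linear, n_splits: int) -> torch.nn.ModuleList:
    # Same column-wise split as split_linear() in npu_models/common.py, minus bias.
    assert module.in_features % n_splits == 0
    sub_linears = torch.nn.ModuleList()
    for w in torch.tensor_split(module.weight, n_splits, dim=1):
        sub = torch.nn.Linear(w.size(1), w.size(0), bias=False)
        sub.weight = torch.nn.Parameter(w.contiguous(), requires_grad=False)
        sub_linears.append(sub)
    return sub_linears

full = torch.nn.Linear(128, 32, bias=False)   # toy shapes, not model shapes
subs = split_linear_ref(full, n_splits=4)     # 4 groups of 32 input channels
x = torch.randn(1, 128)
partial = sum(sub(x[:, i * 32:(i + 1) * 32]) for i, sub in enumerate(subs))
print(torch.allclose(full(x), partial, atol=1e-5))  # True: the split is lossless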