[NPU] Groupwise (#12241)
* dq divide
* fix
* support attn divide
* update qwen2 7b
* divide down_proj & other linear
* use concat & reduce sum
* support scale after
* support qwen2
* w/ mm
* update reshape
* spda
* split
* split 2+
* update
* lm head-> 28
* no scale
* update
* update
* update
* fix style
* fix style
* to split linear
* update
* update code
* address comments
* fix style & remove redundant code & revert benchmark scripts
* fix style & remove code
* update save & load

---------

Co-authored-by: Yang Wang <yang3.wang@intel.com>
This commit is contained in:

parent aedc4edfba
commit e37f951cce

9 changed files with 493 additions and 165 deletions
@@ -30,7 +30,9 @@ current_dir = os.path.dirname(os.path.realpath(__file__))
 def save_npu_model_in_low_bit(repo_id,
                           local_model_hub,
                           low_bit,
-                          max_output_len, max_prompt_len, intra_pp, inter_pp, disable_transpose_value_cache):
+                          max_output_len, max_prompt_len, intra_pp, inter_pp,
+                          disable_transpose_value_cache,
+                          quantization_group_size):
     model_path = get_model_path(repo_id, local_model_hub)
     # Load model in 4 bit,
     # which convert the relevant layers in the model into INT4 format
@@ -47,6 +49,7 @@ def save_npu_model_in_low_bit(repo_id,
             intra_pp=intra_pp,
             inter_pp=inter_pp,
             transpose_value_cache=not disable_transpose_value_cache,
+            quantization_group_size=quantization_group_size
         )
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     end = time.perf_counter()
@@ -54,6 +57,7 @@ def save_npu_model_in_low_bit(repo_id,
 
     model.save_low_bit(model_path+'-npu-'+low_bit)
     tokenizer.save_pretrained(model_path+'-npu-'+low_bit)
+    print(f"Model saved to {model_path+'-npu-'+low_bit}")
 
 
 if __name__ == "__main__":
@@ -65,6 +69,7 @@ if __name__ == "__main__":
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
     parser.add_argument("--intra-pp", type=int, default=2)
     parser.add_argument("--inter-pp", type=int, default=2)
+    parser.add_argument("--quantization_group_size", type=int, default=0)
 
     args = parser.parse_args()
     from omegaconf import OmegaConf
@@ -78,5 +83,6 @@ if __name__ == "__main__":
                               max_prompt_len=args.max_prompt_len,
                               intra_pp=args.intra_pp,
                               inter_pp=args.inter_pp,
-                              disable_transpose_value_cache=args.disable_transpose_value_cache
+                              disable_transpose_value_cache=args.disable_transpose_value_cache,
+                              quantization_group_size=args.quantization_group_size,
                               )
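For reference, a minimal sketch of how the updated benchmark helper might be called once the new argument is threaded through. The repo id, hub path and low_bit value are illustrative assumptions; the other values simply mirror defaults that appear in this commit, and only quantization_group_size is new.

# Illustrative call only: repo_id, local_model_hub and low_bit are assumptions.
save_npu_model_in_low_bit(
    repo_id="Qwen/Qwen2-7B-Instruct",
    local_model_hub="/path/to/local/model/hub",
    low_bit="sym_int4",
    max_output_len=1024,
    max_prompt_len=512,
    intra_pp=2,
    inter_pp=2,
    disable_transpose_value_cache=False,
    quantization_group_size=64,   # new in this commit; 0 keeps the previous per-channel behaviour
)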
@@ -81,6 +81,8 @@ class _BaseAutoModelClass:
         :param mixed_precision: boolean value, Whether to use mixed precision quantization.
             Default to be False. If set to ``True``, we will use ``'sym_int8'`` for lm_head when
             ``load_in_low_bit`` is '``sym_int4``' for certain models.
+        :param quantization_group_size: int, quantization group size, The recommended
+            quantization_group_size are 0, 32, 64 or 128
         :return: a model instance
         """
         if kwargs.get("device_map", None) not in [None, "cpu", "auto"]:
@@ -126,6 +128,15 @@ class _BaseAutoModelClass:
         transpose_value_cache = kwargs.pop("transpose_value_cache", True)
         modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
         mixed_precision = kwargs.pop('mixed_precision', False)
+        quantization_group_size = kwargs.pop("quantization_group_size", 0)
+
+        invalidInputError(
+            quantization_group_size in [0, 32, 64, 128],
+            (
+                "The recommended quantization_group_size are 0, 32, 64 or 128,"
+                f"but got {quantization_group_size}"
+            )
+        )
 
         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)
@@ -162,8 +173,11 @@ class _BaseAutoModelClass:
 
             with torch.no_grad():
                 model.config.update({"mixed_precision": mixed_precision})
-                optimize_llm_pre(model, qtype, mixed_precision)
-                cls.load_convert(qtype, model, "cpu", modules_to_not_convert, *args, **kwargs)
+                model.config.update({"group_size": quantization_group_size})
+                optimize_llm_pre(model, qtype, mixed_precision,
+                                 quantization_group_size=quantization_group_size)
+                cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
+                                 quantization_group_size, *args, **kwargs)
                 create_npu_kernels(llm)
             model = model.eval()
             logger.info(f"Finish to convert model")
@@ -177,6 +191,7 @@ class _BaseAutoModelClass:
                 inter_pp=inter_pp,
                 intra_pp=intra_pp,
                 transpose_value_cache=transpose_value_cache,
+                group_size=quantization_group_size
             )
             model.save_low_bit = types.MethodType(save_low_bit, model)
         else:
@@ -197,11 +212,13 @@ class _BaseAutoModelClass:
         return model
 
     @classmethod
-    def load_convert(cls, q_k, optimize_model, device, modules_to_not_convert, *arg, **kwarg):
+    def load_convert(cls, q_k, optimize_model, device, modules_to_not_convert,
+                     group_size=0, *arg, **kwarg):
         from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear
 
         replace_with_QuantizedLinear(optimize_model, q_k, device=device,
-                                     modules_to_not_convert=modules_to_not_convert)
+                                     modules_to_not_convert=modules_to_not_convert,
+                                     group_size=group_size)
 
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
@@ -214,6 +231,7 @@ class _BaseAutoModelClass:
         ignore_argument(kwargs, "speculative")
         ignore_argument(kwargs, "pipeline_parallel_stages")
         ignore_argument(kwargs, "mixed_precision")
+        ignore_argument(kwargs, "quantization_group_size")
        optimize_model = kwargs.pop("optimize_model", False)
         max_output_len = kwargs.pop("max_output_len", 1024)
         max_prompt_len = kwargs.pop("max_prompt_len", 512)
@@ -264,6 +282,7 @@ class _BaseAutoModelClass:
         qtype = config_dict.pop("bigdl_transformers_low_bit", False)
         bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True)
         mixed_precision = config_dict.pop("mixed_precision", False)
+        quantization_group_size = config_dict.pop("group_size", 0)
 
         invalidInputError(
             qtype,
@@ -376,9 +395,10 @@ class _BaseAutoModelClass:
                 llm = model
 
             with torch.no_grad():
-                optimize_llm_pre(model, qtype, mixed_precision)
+                optimize_llm_pre(model, qtype, mixed_precision,
+                                 quantization_group_size=quantization_group_size)
                 cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
-                                 *model_args, **kwargs)
+                                 quantization_group_size, *model_args, **kwargs)
                 create_npu_kernels(llm)
 
         else:
@@ -458,6 +478,7 @@ class _BaseAutoModelClass:
                 inter_pp=inter_pp,
                 intra_pp=intra_pp,
                 transpose_value_cache=transpose_value_cache,
+                group_size=quantization_group_size
             )
 
         return model
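A hedged end-to-end sketch of the new user-facing knob, assuming the usual ipex-llm NPU AutoModelForCausalLM entry point and a supported Qwen2 checkpoint; everything except quantization_group_size (and the kwargs visible in the diff above) is illustrative rather than prescribed by this commit.

from ipex_llm.transformers.npu_model import AutoModelForCausalLM   # assumed entry point for the class above

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",              # illustrative checkpoint
    optimize_model=True,
    load_in_low_bit="sym_int4_rtn",
    max_output_len=1024,
    max_prompt_len=512,
    transpose_value_cache=True,
    quantization_group_size=64,            # validated against [0, 32, 64, 128] above
    trust_remote_code=True,
)

# "group_size" is written into model.config, so load_low_bit() can pick it up again.
model.save_low_bit("./Qwen2-7B-Instruct-npu-sym_int4")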
@@ -16,6 +16,7 @@
 
 import torch
 from typing import List
+from ipex_llm.utils.common.log4Error import invalidInputError
 
 
 def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
@@ -40,3 +41,21 @@ def reshape_lm_head_input(x):
         shape[1] = 1
         x = x[:, -1, :].view(shape)
     return x
+
+
+def split_linear(module, module_name, n_splits=2):
+    in_features = module.in_features
+    invalidInputError(in_features % n_splits == 0,
+                      f"in_features of the linear layer {module_name} must be divisible by"
+                      f" n_splits, but got in_features: {in_features}, n_splits: {n_splits}")
+    weight_split = torch.tensor_split(module.weight, n_splits, dim=1)
+    linear_list = torch.nn.ModuleList()
+    bias = module.bias
+    for idx, weight in enumerate(weight_split):
+        new_linear = torch.nn.Linear(weight.size(1),
+                                     weight.size(0),
+                                     bias=False if bias is None else True)
+        new_linear.bias = bias
+        new_linear.weight = torch.nn.Parameter(weight.contiguous(), requires_grad=False)
+        linear_list.add_module(f"{module_name}_dq_{idx}", new_linear)
+    return linear_list
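The new split_linear helper cuts a Linear along its input features (dim=1 of the weight), which is what lets each narrow block carry its own quantization scale. A small torch-only check of the identity this relies on (a sketch, not the library code itself): applying each sub-linear to the matching input slice and summing the partial results reproduces the original output.

import torch

torch.manual_seed(0)
n_splits = 4
lin = torch.nn.Linear(128, 64, bias=False)   # bias-free so the plain sum is exact
x = torch.randn(2, 128)

# Split the weight along in_features, exactly as split_linear does ...
w_parts = torch.tensor_split(lin.weight, n_splits, dim=1)
# ... and slice the activation along the same dimension.
x_parts = torch.tensor_split(x, n_splits, dim=-1)

# Each partial product uses only a (64 x 32) weight block; summing them equals the full matmul.
y_split = sum(xp @ wp.t() for xp, wp in zip(x_parts, w_parts))
assert torch.allclose(y_split, lin(x), atol=1e-5)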
@@ -31,7 +31,8 @@ def module_optimization(func) -> torch.nn.Module:
         torch.nn.Module: optimized module
     """
 
-    def wrapper(model: torch.nn.Module, qtype, device, modules_to_not_convert, *args, **kwargs):
+    def wrapper(model: torch.nn.Module, qtype, device, modules_to_not_convert,
+                group_size=0, *args, **kwargs):
         """Recursively apply the optimization function.
 
         Args:
@@ -42,18 +43,22 @@ def module_optimization(func) -> torch.nn.Module:
         """
         for name, layer in model.named_children():
             if name not in modules_to_not_convert:
-                new_layer = func(layer, qtype, device, modules_to_not_convert, *args, **kwargs)
+                new_layer = func(layer, qtype, device, modules_to_not_convert,
+                                 group_size=group_size, *args, **kwargs)
                 if new_layer:
                     model.add_module(name, new_layer)
-                    wrapper(new_layer, qtype, device, modules_to_not_convert, *args, **kwargs)
+                    wrapper(new_layer, qtype, device, modules_to_not_convert,
+                            group_size=group_size, *args, **kwargs)
                 else:
-                    wrapper(layer, qtype, device, modules_to_not_convert, *args, **kwargs)
+                    wrapper(layer, qtype, device, modules_to_not_convert,
+                            group_size=group_size, *args, **kwargs)
 
     return wrapper
 
 
 @module_optimization
-def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert):
+def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
+                                 group_size):
     from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
@@ -66,7 +71,8 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert):
                 iqtype = ggml_tensor_qtype[qtype]
         qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
                                              iqtype, device=device)
-        return QuantizedLinear(qweights, scale, layer.bias)
+        return QuantizedLinear(qweights, scale, layer.bias,
+                               group_size=group_size)
 
 
 def convert_forward(m, target_m, new_forward):
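A standalone sketch (not the library code) of how the decorated wrapper threads group_size through the recursive walk over named_children; the tag_linears stand-in is hypothetical and only records the value where replace_with_QuantizedLinear would build a QuantizedLinear.

import torch

def module_optimization(func):
    # Same recursion pattern as above: forward group_size to every nested call.
    def wrapper(model, qtype, device, modules_to_not_convert, group_size=0, *args, **kwargs):
        for name, layer in model.named_children():
            if name not in modules_to_not_convert:
                new_layer = func(layer, qtype, device, modules_to_not_convert,
                                 group_size=group_size, *args, **kwargs)
                if new_layer:
                    model.add_module(name, new_layer)
                    wrapper(new_layer, qtype, device, modules_to_not_convert,
                            group_size=group_size, *args, **kwargs)
                else:
                    wrapper(layer, qtype, device, modules_to_not_convert,
                            group_size=group_size, *args, **kwargs)
    return wrapper

@module_optimization
def tag_linears(layer, qtype, device, modules_to_not_convert, group_size):
    # Stand-in for replace_with_QuantizedLinear: just record the group size on each Linear.
    if isinstance(layer, torch.nn.Linear):
        layer.npu_group_size = group_size

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))
tag_linears(model, "sym_int4_rtn", "cpu", modules_to_not_convert=[], group_size=64)
print([getattr(m, "npu_group_size", None) for m in model])   # [64, None, 64]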
@@ -19,6 +19,7 @@ import importlib
 import numpy as np
 from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params
 from ipex_llm.transformers.npu_models.lm_head import LMHeadLinear, SlicedLMHead
+from ipex_llm.utils.common.log4Error import invalidInputError
 
 
 def convert_forward(m, target_m, new_forward):
@@ -29,7 +30,8 @@ def convert_forward(m, target_m, new_forward):
         convert_forward(sub_m, target_m, new_forward)
 
 
-def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision):
+def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
+                     quantization_group_size=0):
     if model.config.model_type == "baichuan":
         # process NormHead module in Baichuan2 7B
         if hasattr(model, 'lm_head') and model.lm_head is not None:
@@ -86,17 +88,40 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision):
         model = model.llm
 
     if model.config.model_type == "qwen2":
-        from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj
-        model.apply(split_mlp_down_proj)
+        from ipex_llm.transformers.npu_models.qwen2_mp import split_linears
+
+        if quantization_group_size == 0:
+            n_splits_linear = 1
+            n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
+        else:
+            invalidInputError(
+                model.config.hidden_size % quantization_group_size == 0 and
+                model.config.intermediate_size % quantization_group_size == 0,
+                "The model hidden_size and intermediate_size should be divisible by "
+                f"quantization_group_size, but got hidden_size: {model.config.hidden_size}, "
+                f"intermediate_size: {model.config.intermediate_size}, and "
+                f"quantization_group_size: {quantization_group_size}"
+            )
+            n_splits_linear = model.config.hidden_size // quantization_group_size
+            n_splits_down_proj = model.config.intermediate_size // quantization_group_size
+
+        model.apply(lambda m: split_linears(m, n_splits_hidden_size=n_splits_linear,
+                                            n_splits_down_proj=n_splits_down_proj))
+
         # for Qwen2-7B-Insturct, divide lm_head into 14 parts
         if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
                 not cpu_lm_head:
             # Do not split lm_head and use sym_int8 instead when mixed_precison is True
+            if quantization_group_size != 0:
+                split_num = model.config.hidden_size // quantization_group_size
+                new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
+                                           bias=model.lm_head.bias, use_split=True)
+            else:
+                # Do not split lm_head and use sym_int8 instead when mixed_precison is True
                 is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
                 split_num = 14 if is_split else 1
                 new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
-                                       bias=model.lm_head.bias)
+                                           bias=model.lm_head.bias, use_split=False)
             del model.lm_head
             model.lm_head = new_lm_head
 
@@ -132,6 +157,7 @@ def optimize_llm(
     inter_pp=None,
     intra_pp=None,
     transpose_value_cache=True,
+    group_size=0
 ):
     if model.config.model_type == "llama":
         if intra_pp is None:
@@ -168,7 +194,13 @@ def optimize_llm(
         if intra_pp is None:
             intra_pp = 2
         if inter_pp is None:
-            inter_pp = 2 if model.config.intermediate_size == 18944 else 1
+            if model.config.intermediate_size == 18944:
+                if group_size != 0:
+                    inter_pp = 5
+                else:
+                    inter_pp = 2
+            else:
+                inter_pp = 1
 
         from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward
         from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner
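A worked example of the split counts for the Qwen2-7B shapes that the branch above special-cases (hidden_size 3584, intermediate_size 18944). With quantization_group_size=128 this is also where the lm_head split of 28 mentioned in the commit message comes from.

hidden_size, intermediate_size = 3584, 18944     # Qwen2-7B-Instruct, per the checks above

# group_size == 0: keep the old behaviour (only down_proj is split in two for the 18944 case)
print(1, 2 if intermediate_size == 18944 else 1)              # 1 2

# group_size == 128: every linear is cut into 128-wide input groups
group_size = 128
assert hidden_size % group_size == 0 and intermediate_size % group_size == 0
n_splits_linear = hidden_size // group_size                   # 28  (also the lm_head split_num)
n_splits_down_proj = intermediate_size // group_size          # 148 (down_proj input is intermediate_size)
print(n_splits_linear, n_splits_down_proj)                    # 28 148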
@@ -130,6 +130,7 @@ class QuantizedLinear(torch.nn.Module):
         weight: torch.Tensor,
         scale: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
+        group_size: int = False,
     ):
         """Initialize the QuantizedLinear class.
 
@@ -154,8 +155,11 @@ class QuantizedLinear(torch.nn.Module):
                 )
             )
         self.outC, self.inC = self.weight.shape
+        if group_size != 0:
+            self.scale = Parameter(scale, requires_grad=False)
+        else:
             if self.weight.dtype == torch.uint8:
-            # In case is Int4 we need to double the input channels because weights are compressed
+                # Int4 we need to double the input channels because weights are compressed
                 self.inC *= 2
             self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False)
         self.bias = bias
@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import torch
 from torch import nn
 import numpy as np
+from filelock import FileLock
 from intel_npu_acceleration_library.backend import NNFactory
 from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
 
@@ -34,6 +34,7 @@ class LMHeadLinear(NNFactory):
         profile: bool = False,
         device: str = "NPU",
         dtype: np.dtype = np.int8,
+        use_split: bool = False,
     ):
         """Initialize the LMHeadLinear class.
 
@@ -51,9 +52,14 @@ class LMHeadLinear(NNFactory):
         self.inC, self.outC = inC, outC
         self.batch = batch
 
-        input = self.parameter((self.batch, self.inC))
-
         self.split_num = split_num
+
+        if use_split:
+            input = self.parameter((1, self.batch, self.inC))
+            res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
+                                       scale_factor=False)
+        else:
+            input = self.parameter((self.batch, self.inC))
             split_size = self.inC // split_num // 2 * 2
 
             for i in range(self.split_num):
@@ -61,7 +67,8 @@ class LMHeadLinear(NNFactory):
                 end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
                 input_slice = self.slice(input, begin=[0, start_idx],
                                          end=[self.batch, end_idx])
-            linear_slice = self.linear(input_slice, outC, split_size, bias=False, wt_dtype=dtype)
+                linear_slice = self.linear(input_slice, outC, split_size, bias=False,
+                                           wt_dtype=dtype)
                 if i == 0:
                     res = linear_slice
                 else:
@@ -71,6 +78,14 @@ class LMHeadLinear(NNFactory):
         self.compile()
         print("end compiling lm_head")
 
+    def set_weights(self, op_id, weights):
+        self.set_weights_async(op_id, weights)
+        with FileLock(f"lmhead_run.lock"):
+            backend_lib.run(self._mm)
+
+    def set_weights_async(self, op_id, weights):
+        self.setWeights(1, op_id, *weights)
+
     def run(
         self, X: np.ndarray
     ) -> np.ndarray:
@@ -93,7 +108,7 @@ class LMHeadLinear(NNFactory):
 
 
 class SlicedLMHead(nn.Module):
-    def __init__(self, weight, bias, split_num):
+    def __init__(self, weight, bias, split_num, use_split=False):
         super().__init__()
         self.split_num = split_num
         self.outC, self.inC = weight.shape
@@ -110,6 +125,7 @@ class SlicedLMHead(nn.Module):
             new_linear.out_features = new_weight.size(0)
             self.lm_heads.append(new_linear)
         self.bias = bias
+        self.use_split = use_split
 
     def forward(self, hidden_states):
         if hidden_states.size(0) * hidden_states.size(1) == 1:
@@ -143,9 +159,19 @@ class SlicedLMHead(nn.Module):
     def get_fused_lm_head(self):
         np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
         self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
-                                          False, "NPU", dtype=np_dtype)
+                                          False, "NPU", dtype=np_dtype, use_split=self.use_split)
+        if self.use_split:
+            weights = []
+            scales = []
+            for i in range(self.split_num):
+                weights.append(self.lm_heads[i].weight)
+                scales.append(self.lm_heads[i].scale)
+            fused_lm_head_weights = (torch.stack(weights, axis=0).numpy(),
+                                     torch.stack(scales, axis=0).numpy())
+        else:
            fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
                                      self.lm_heads[i].scale.data.numpy())
                                     for i in range(self.split_num)]
-        self.fused_lm_head.setWeights(1, self.lm_heads[0].op_id,
-                                      *fused_lm_head_weights)
+
+        self.fused_lm_head.set_weights(self.lm_heads[0].op_id,
+                                       fused_lm_head_weights)
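Shape sketch for the new use_split branch of get_fused_lm_head above. The heads appear to be sliced along the input dimension (matching the slicing done in LMHeadLinear), so each slice holds an (outC, inC // split_num) weight and stacking adds a leading split_num axis; the per-slice scales are stacked the same way and both are handed to set_weights as one tuple. The sizes below are small illustrative stand-ins, not the real Qwen2 lm_head dimensions, and the per-output-channel scale shape is an assumption.

import torch

outC, inC, split_num = 8, 16, 4                         # illustrative sizes only
weights = [torch.zeros(outC, inC // split_num) for _ in range(split_num)]
scales = [torch.ones(outC) for _ in range(split_num)]   # assumption: one scale per output channel

stacked_w = torch.stack(weights, dim=0)
stacked_s = torch.stack(scales, dim=0)
print(stacked_w.shape, stacked_s.shape)                 # torch.Size([4, 8, 4]) torch.Size([4, 8])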
| 
						 | 
					@ -27,6 +27,8 @@ from filelock import FileLock
 | 
				
			||||||
import ctypes
 | 
					import ctypes
 | 
				
			||||||
import math
 | 
					import math
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
 | 
					from typing import Optional, Any, List
 | 
				
			||||||
 | 
					import numpy.typing as npt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.get_logger(__name__)
 | 
					logger = logging.get_logger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -60,6 +62,12 @@ def run_model(
 | 
				
			||||||
            op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy()))
 | 
					            op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy()))
 | 
				
			||||||
            op_args_flatten.append(op_args[-1][0])
 | 
					            op_args_flatten.append(op_args[-1][0])
 | 
				
			||||||
            op_args_flatten.append(op_args[-1][1])
 | 
					            op_args_flatten.append(op_args[-1][1])
 | 
				
			||||||
 | 
					        elif w.dtype in [torch.int8, torch.uint8]:    # QuantizedLinear weight
 | 
				
			||||||
 | 
					            op_args.append(w.numpy())
 | 
				
			||||||
 | 
					            op_args_flatten.append(op_args[-1])
 | 
				
			||||||
 | 
					        elif isinstance(w, np.ndarray):     # scale
 | 
				
			||||||
 | 
					            op_args.append(w)
 | 
				
			||||||
 | 
					            op_args_flatten.append(op_args[-1])
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            op_args.append(set_contiguous(w).to(torch.float16).numpy())
 | 
					            op_args.append(set_contiguous(w).to(torch.float16).numpy())
 | 
				
			||||||
            op_args_flatten.append(op_args[-1])
 | 
					            op_args_flatten.append(op_args[-1])
 | 
				
			||||||
| 
						 | 
					@ -94,7 +102,8 @@ def run_model(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LLMBaseNNFactory(NNFactory):
 | 
					class LLMBaseNNFactory(NNFactory):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, max_seq_len, transpose_value, dtype, profile=False, device="NPU"):
 | 
					    def __init__(self, max_seq_len, transpose_value, dtype, profile=False, device="NPU",
 | 
				
			||||||
 | 
					                 n_splits_linear=1, n_splits_down_proj=1, group_size=False):
 | 
				
			||||||
        super().__init__(profile, device)
 | 
					        super().__init__(profile, device)
 | 
				
			||||||
        self.cache_parameter_ops = []
 | 
					        self.cache_parameter_ops = []
 | 
				
			||||||
        self.input_ops = []
 | 
					        self.input_ops = []
 | 
				
			||||||
| 
						 | 
					@ -104,6 +113,9 @@ class LLMBaseNNFactory(NNFactory):
 | 
				
			||||||
        self.max_seq_len = max_seq_len
 | 
					        self.max_seq_len = max_seq_len
 | 
				
			||||||
        self.transpose_value = transpose_value
 | 
					        self.transpose_value = transpose_value
 | 
				
			||||||
        self.dtype = dtype
 | 
					        self.dtype = dtype
 | 
				
			||||||
 | 
					        self.n_splits_linear = n_splits_linear
 | 
				
			||||||
 | 
					        self.n_splits_down_proj = n_splits_down_proj
 | 
				
			||||||
 | 
					        self.group_size = group_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def attention(self,
 | 
					    def attention(self,
 | 
				
			||||||
                  *,
 | 
					                  *,
 | 
				
			||||||
| 
						 | 
					@ -124,6 +136,8 @@ class LLMBaseNNFactory(NNFactory):
 | 
				
			||||||
                  v_bias=None):
 | 
					                  v_bias=None):
 | 
				
			||||||
        hidden_size = num_heads * head_dim
 | 
					        hidden_size = num_heads * head_dim
 | 
				
			||||||
        num_key_value_groups = num_heads // num_key_value_heads
 | 
					        num_key_value_groups = num_heads // num_key_value_heads
 | 
				
			||||||
 | 
					        groupsize = hidden_size // self.n_splits_linear
 | 
				
			||||||
 | 
					        if self.n_splits_linear == 1:
 | 
				
			||||||
            query_states = self.linear(
 | 
					            query_states = self.linear(
 | 
				
			||||||
                hidden_states,
 | 
					                hidden_states,
 | 
				
			||||||
                num_heads * head_dim,
 | 
					                num_heads * head_dim,
 | 
				
			||||||
| 
						 | 
					@ -131,8 +145,7 @@ class LLMBaseNNFactory(NNFactory):
 | 
				
			||||||
                bias=False,
 | 
					                bias=False,
 | 
				
			||||||
                wt_dtype=self.dtype,
 | 
					                wt_dtype=self.dtype,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        if q_bias is not None:
 | 
					
 | 
				
			||||||
            query_states = query_states + q_bias
 | 
					 | 
				
			||||||
            key_states = self.linear(
 | 
					            key_states = self.linear(
 | 
				
			||||||
                hidden_states,
 | 
					                hidden_states,
 | 
				
			||||||
                num_key_value_heads * head_dim,
 | 
					                num_key_value_heads * head_dim,
 | 
				
			||||||
| 
						 | 
					@ -140,8 +153,7 @@ class LLMBaseNNFactory(NNFactory):
 | 
				
			||||||
                bias=False,
 | 
					                bias=False,
 | 
				
			||||||
                wt_dtype=self.dtype,
 | 
					                wt_dtype=self.dtype,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        if k_bias is not None:
 | 
					
 | 
				
			||||||
            key_states = key_states + k_bias
 | 
					 | 
				
			||||||
            value_states = self.linear(
 | 
					            value_states = self.linear(
 | 
				
			||||||
                hidden_states,
 | 
					                hidden_states,
 | 
				
			||||||
                num_key_value_heads * head_dim,
 | 
					                num_key_value_heads * head_dim,
 | 
				
			||||||
| 
						 | 
					@ -149,6 +161,67 @@ class LLMBaseNNFactory(NNFactory):
 | 
				
			||||||
                bias=False,
 | 
					                bias=False,
 | 
				
			||||||
                wt_dtype=self.dtype,
 | 
					                wt_dtype=self.dtype,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            hidden_states = self.unsqueeze(hidden_states, axis=0)
 | 
				
			||||||
 | 
					            if mode == "prefill":
 | 
				
			||||||
 | 
					                query_states_to_concat = []
 | 
				
			||||||
 | 
					                key_states_to_concat = []
 | 
				
			||||||
 | 
					                value_states_to_concat = []
 | 
				
			||||||
 | 
					                for i in range(self.n_splits_linear):
 | 
				
			||||||
 | 
					                    sub_hidden_states = self.slice(hidden_states,
 | 
				
			||||||
 | 
					                                                   begin=[0, 0, i * groupsize],
 | 
				
			||||||
 | 
					                                                   end=[1, seq_len, (i + 1) * groupsize])
 | 
				
			||||||
 | 
					                    query_states_to_concat.append(
 | 
				
			||||||
 | 
					                        self.linear(
 | 
				
			||||||
 | 
					                            sub_hidden_states,
 | 
				
			||||||
 | 
					                            num_heads * head_dim,
 | 
				
			||||||
 | 
					                            groupsize,
 | 
				
			||||||
 | 
					                            bias=False,
 | 
				
			||||||
 | 
					                            wt_dtype=self.dtype,
 | 
				
			||||||
 | 
					                            scale_factor=(self.group_size == 0)
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    key_states_to_concat.append(
 | 
				
			||||||
 | 
					                        self.linear(
 | 
				
			||||||
 | 
					                            sub_hidden_states,
 | 
				
			||||||
 | 
					                            num_key_value_heads * head_dim,
 | 
				
			||||||
 | 
					                            groupsize,
 | 
				
			||||||
 | 
					                            bias=False,
 | 
				
			||||||
 | 
					                            wt_dtype=self.dtype,
 | 
				
			||||||
 | 
					                            scale_factor=(self.group_size == 0)
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    value_states_to_concat.append(
 | 
				
			||||||
 | 
					                        self.linear(
 | 
				
			||||||
 | 
					                            sub_hidden_states,
 | 
				
			||||||
 | 
					                            num_key_value_heads * head_dim,
 | 
				
			||||||
 | 
					                            groupsize,
 | 
				
			||||||
 | 
					                            bias=False,
 | 
				
			||||||
 | 
					                            wt_dtype=self.dtype,
 | 
				
			||||||
 | 
					                            scale_factor=(self.group_size == 0)
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                query_states = sum(query_states_to_concat)
 | 
				
			||||||
 | 
					                key_states = sum(key_states_to_concat)
 | 
				
			||||||
 | 
					                value_states = sum(value_states_to_concat)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
 | 
				
			||||||
 | 
					                                                    hidden_size, self.n_splits_linear,
 | 
				
			||||||
 | 
					                                                    wt_dtype=self.dtype,
 | 
				
			||||||
 | 
					                                                    scale_factor=(self.group_size == 0))
 | 
				
			||||||
 | 
					                key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
 | 
				
			||||||
 | 
					                                                  hidden_size, self.n_splits_linear,
 | 
				
			||||||
 | 
					                                                  wt_dtype=self.dtype,
 | 
				
			||||||
 | 
					                                                  scale_factor=(self.group_size == 0))
 | 
				
			||||||
 | 
					                value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
 | 
				
			||||||
 | 
					                                                    hidden_size, self.n_splits_linear,
 | 
				
			||||||
 | 
					                                                    wt_dtype=self.dtype,
 | 
				
			||||||
 | 
					                                                    scale_factor=(self.group_size == 0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if q_bias is not None:
 | 
				
			||||||
 | 
					            query_states = query_states + q_bias
 | 
				
			||||||
 | 
					        if k_bias is not None:
 | 
				
			||||||
 | 
					            key_states = key_states + k_bias
 | 
				
			||||||
        if v_bias is not None:
 | 
					        if v_bias is not None:
 | 
				
			||||||
            value_states = value_states + v_bias
 | 
					            value_states = value_states + v_bias
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -215,23 +288,100 @@ class LLMBaseNNFactory(NNFactory):
 | 
				
			||||||
        attn_output = self.transpose(attn_output, [0, 2, 1, 3])
 | 
					        attn_output = self.transpose(attn_output, [0, 2, 1, 3])
 | 
				
			||||||
        attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])
 | 
					        attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.n_splits_linear == 1:
 | 
				
			||||||
            attn_output = self.linear(
 | 
					            attn_output = self.linear(
 | 
				
			||||||
                attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
 | 
					                attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            if mode == "prefill":
 | 
				
			||||||
 | 
					                attn_output_to_concat = []
 | 
				
			||||||
 | 
					                for i in range(self.n_splits_linear):
 | 
				
			||||||
 | 
					                    sub_attn_output = self.slice(attn_output,
 | 
				
+                                                 begin=[0, 0, i * groupsize],
+                                                 end=[1, seq_len, (i + 1) * groupsize])
+                    attn_output_to_concat.append(
+                        self.linear(
+                            sub_attn_output, hidden_size, groupsize, bias=False,
+                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
+                        )
+                    )
+                attn_output = sum(attn_output_to_concat)
+            else:
+                attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
+                                                   self.n_splits_linear, wt_dtype=self.dtype,
+                                                   scale_factor=(self.group_size == 0))
+
         return attn_output, new_key_states, new_value_states

-    def mlp(self, hidden_states):
-        mm1 = self.linear(
-            hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
-        )
-        mm2 = self.linear(
-            hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
-        )  # type: ignore[attr-defined]
-        mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
-        hidden_states = self.linear(
-            mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
-        )
+    def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
+        if self.n_splits_linear == 1:
+            mm1 = self.linear(
+                hidden_states, self.intermediate_size, self.hidden_size, bias=False,
+                wt_dtype=self.dtype
+            )
+            mm2 = self.linear(
+                hidden_states, self.intermediate_size, self.hidden_size, bias=False,
+                wt_dtype=self.dtype
+            )  # type: ignore[attr-defined]
+            mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
+        else:
+            invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
+            if mode == "prefill":
+                gate_up_groupsize = self.hidden_size // self.n_splits_linear
+                mm1_to_concat = []
+                mm2_to_concat = []
+                for i in range(self.n_splits_linear):
+                    sub_hidden_states = self.slice(hidden_states,
+                                                   begin=[0, 0, i * gate_up_groupsize],
+                                                   end=[1, seq_len, (i + 1) * gate_up_groupsize])
+                    mm1_to_concat.append(
+                        self.linear(
+                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
+                            bias=False,
+                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
+                        )
+                    )
+                    mm2_to_concat.append(
+                        self.linear(
+                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
+                            bias=False,
+                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
+                        )
+                    )
+                mm1 = sum(mm1_to_concat)
+                mm2 = sum(mm2_to_concat)
+            else:
+                mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                           self.n_splits_linear, wt_dtype=self.dtype,
+                                           scale_factor=(self.group_size == 0))
+                mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                           self.n_splits_linear, wt_dtype=self.dtype,
+                                           scale_factor=(self.group_size == 0))
+            mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
+
+        if self.n_splits_down_proj == 1:
+            hidden_states = self.linear(
+                mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
+            )
+        else:
+            invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
+            if mode == "prefill":
+                down_groupsize = self.intermediate_size // self.n_splits_down_proj
+                hidden_states_to_concat = []
+                for i in range(self.n_splits_down_proj):
+                    sub_mm1 = self.slice(mm1, begin=[0, 0, i * down_groupsize],
+                                         end=[1, seq_len, (i + 1) * down_groupsize])
+                    hidden_states_to_concat.append(
+                        self.linear(
+                            sub_mm1, self.hidden_size, down_groupsize, bias=False,
+                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
+                        )
+                    )
+                hidden_states = sum(hidden_states_to_concat)
+            else:
+                hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
+                                                     self.n_splits_down_proj, wt_dtype=self.dtype,
+                                                     scale_factor=(self.group_size == 0))
         return hidden_states

     def layer_norm(self, hidden_states, layernorm_weight):
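Note on the prefill branch above (a minimal sketch in plain PyTorch rather than the NNFactory graph ops; every name below is illustrative): slicing the activation along the hidden dimension, running one narrow linear per group, and summing the partial outputs is mathematically identical to the original single projection. The decode branch reaches the same result through the fused dq_split_linear op instead of an explicit slice-and-sum loop.

    import torch

    torch.manual_seed(0)
    hidden_size, intermediate_size, seq_len, n_splits = 64, 256, 8, 4
    x = torch.randn(1, seq_len, hidden_size)
    w = torch.randn(intermediate_size, hidden_size)      # full weight, (out, in)

    groupsize = hidden_size // n_splits
    partial_sum = torch.zeros(1, seq_len, intermediate_size)
    for i in range(n_splits):
        x_i = x[..., i * groupsize:(i + 1) * groupsize]  # slice of the activation
        w_i = w[:, i * groupsize:(i + 1) * groupsize]    # matching weight columns
        partial_sum = partial_sum + x_i @ w_i.T          # one narrow matmul per group

    assert torch.allclose(partial_sum, x @ w.T, atol=1e-4)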
@@ -341,6 +491,19 @@ class LLMBaseNNFactory(NNFactory):
         self.linear_ops.append(op)
         return op

+    def dq_split_linear(self,
+                        input_node: ctypes._Pointer,
+                        output_channels: int,
+                        input_channels: int,
+                        n_splits: int,
+                        act_dtype: npt.DTypeLike = np.float16,
+                        wt_dtype: npt.DTypeLike = np.float16,
+                        scale_factor: bool = False):
+        op = super().dq_split_linear(input_node, n_splits, output_channels, input_channels,
+                                     False, act_dtype, wt_dtype, scale_factor)
+        self.linear_ops.append(op)
+        return op
+
     def parameter(self, shape):
         invalidInputError(False,
                           ("parameter should not be called directly, "
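For reference, a sketch of the computation a dequantize-and-split linear performs, assuming symmetric per-output-channel int8 quantization for each split (the quantization convention and memory layout of the real backend op are assumptions here, not taken from this diff):

    import numpy as np

    def dq_split_linear_ref(x, int8_weights, scales):
        """x: (seq, in); int8_weights: (n_splits, out, in // n_splits);
        scales: (n_splits, out). Returns (seq, out) in float32."""
        n_splits, out_ch, group = int8_weights.shape
        y = np.zeros((x.shape[0], out_ch), dtype=np.float32)
        for i in range(n_splits):
            w_i = int8_weights[i].astype(np.float32) * scales[i][:, None]  # dequantize split i
            x_i = x[:, i * group:(i + 1) * group]                          # matching input slice
            y += x_i @ w_i.T                                               # accumulate partial output
        return y

    x = np.random.randn(4, 16).astype(np.float32)
    w8 = np.random.randint(-127, 128, size=(2, 8, 8), dtype=np.int8)
    s = (np.random.rand(2, 8) / 127).astype(np.float32)
    print(dq_split_linear_ref(x, w8, s).shape)    # (4, 8)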
@@ -42,7 +42,27 @@ from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
 from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from torch.nn import CrossEntropyLoss
-from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
+from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention
+from ipex_llm.utils.common.log4Error import invalidInputError
+from ipex_llm.transformers.npu_models.common import split_linear
+
+
+def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down_proj=2):
+    attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
+    mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
+    if isinstance(module, Qwen2Attention):
+        for name in attn_module_names:
+            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
+                                                            n_splits=n_splits_hidden_size))
+            delattr(module, name)
+    elif isinstance(module, Qwen2MLP):
+        for name in mlp_module_names:
+            n_splits_mlp = n_splits_hidden_size
+            if name == 'down_proj':
+                n_splits_mlp = n_splits_down_proj
+            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
+                                                            n_splits=n_splits_mlp))
+            delattr(module, name)


 def split_mlp_down_proj(module: torch.nn.Module):
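A hypothetical stand-in for split_linear (the real helper is imported from ipex_llm.transformers.npu_models.common and also receives the attribute name, which it appears to use for naming the sub-layers such as q_proj_dq_0; the sketch below only shows the splitting idea):

    import torch

    def split_linear_sketch(linear: torch.nn.Linear, n_splits: int = 2):
        out_f, in_f = linear.out_features, linear.in_features
        assert in_f % n_splits == 0
        group = in_f // n_splits
        splits = torch.nn.ModuleList()
        for i in range(n_splits):
            # Keep the bias on the first split only, so the summed result counts it once.
            sub = torch.nn.Linear(group, out_f, bias=(linear.bias is not None and i == 0))
            sub.weight.data.copy_(linear.weight.data[:, i * group:(i + 1) * group])
            if sub.bias is not None:
                sub.bias.data.copy_(linear.bias.data)
            splits.append(sub)
        return splits

    full = torch.nn.Linear(32, 16)
    parts = split_linear_sketch(full, n_splits=4)
    x = torch.randn(2, 32)
    y = sum(p(x[:, i * 8:(i + 1) * 8]) for i, p in enumerate(parts))
    assert torch.allclose(y, full(x), atol=1e-5)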
@@ -94,12 +114,18 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         device: str = "NPU",
         rms_norm_eps,
         intermediate_size,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
                          dtype=dtype,
                          profile=profile,
-                         device=device)
+                         device=device,
+                         n_splits_linear=n_splits_linear,
+                         n_splits_down_proj=n_splits_down_proj,
+                         group_size=group_size)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -221,32 +247,9 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
             new_key_states = self.convert_to_fp16(curr_key_values[i][0])
             new_value_states = self.convert_to_fp16(curr_key_values[i][1])

-        print("start compiling")
+        print(f"{mode} start compiling")
         self.compile()
-        print("end compiling")
-
-    def mlp(self, hidden_states, seq_len):
-        mm1 = self.linear(
-            hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
-        )
-        mm2 = self.linear(
-            hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
-        )  # type: ignore[attr-defined]
-        mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
-        if self.intermediate_size == 18944:
-            # for qwen2-7b
-            mm1_0 = self.slice(mm1, begin=[0, 0, 0], end=[1, seq_len, 9472])
-            mm1_1 = self.slice(mm1, begin=[0, 0, 9472], end=[1, seq_len, 18944])
-            hidden_states_0 = self.linear(mm1_0, self.hidden_size, 9472,
-                                          bias=False, wt_dtype=self.dtype)
-            hidden_states_1 = self.linear(mm1_1, self.hidden_size, 9472,
-                                          bias=False, wt_dtype=self.dtype)
-            hidden_states = hidden_states_0 + hidden_states_1
-        else:
-            hidden_states = self.linear(
-                mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
-            )
-        return hidden_states
+        print(f"{mode} end compiling")

     def build_decoder(
         self,
@@ -285,7 +288,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         hidden_states = self.eltwise_add(residual, attn_output)
         residual = hidden_states
         hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
-        hidden_states = self.mlp(hidden_states, self.seq_len)
+        hidden_states = self.mlp(hidden_states, self.seq_len, self.mode)
         hidden_states = self.eltwise_add(residual, hidden_states)
         hidden_states = self.convert_to_fp16(hidden_states)

@@ -314,6 +317,9 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
         max_seq_len: int = 1024,
         transpose_value: bool = False,
         do_print: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0,
     ):
         super().__init__()

@@ -323,6 +329,10 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
         for w in parameters:
             if isinstance(w, tuple):  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif w.dtype in [torch.int8, torch.uint8]:    # QuantizedLinear weight
+                op_parameters.append(w.numpy())
+            elif isinstance(w, np.ndarray):     # scale
+                op_parameters.append(w)
             else:
                 op_parameters.append(w.to(torch.float16).numpy())
         self.op_parameters = op_parameters
@@ -331,6 +341,10 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
         self.transpose_value = transpose_value
         if isinstance(parameters[0], tuple):
             np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
+        elif parameters[0].dtype == torch.int8:
+            np_dtype = np.int8
+        elif parameters[0].dtype == torch.uint8:
+            np_dtype = np.uint8
         else:  # FP16 Linear
             np_dtype = np.float16

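The two constructor hunks above extend parameter handling beyond (weight, scale) tuples to bare int8/uint8 weight tensors and numpy scale arrays. A condensed, illustrative version of that normalization (a sketch, not the class itself):

    import numpy as np
    import torch

    def to_op_parameter(w):
        if isinstance(w, tuple):                  # (weight, scale) pair from QuantizedLinear
            return (w[0].numpy(), w[1].numpy())
        if isinstance(w, np.ndarray):             # scale already converted to numpy
            return w
        if w.dtype in (torch.int8, torch.uint8):  # bare (possibly stacked) quantized weight
            return w.numpy()
        return w.to(torch.float16).numpy()        # FP16 weights / biases

    params = [
        (torch.zeros(8, 8, dtype=torch.int8), torch.ones(8, dtype=torch.float16)),
        torch.zeros(2, 8, 4, dtype=torch.int8),
        np.ones((2, 8), dtype=np.float16),
        torch.randn(8),
    ]
    print([type(to_op_parameter(p)) for p in params])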
@@ -368,6 +382,9 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
                 mode="decode",
                 transpose_value=self.transpose_value,
                 dtype=np_dtype,
+                n_splits_linear=n_splits_linear,
+                n_splits_down_proj=n_splits_down_proj,
+                group_size=group_size
             )
             self.backend_decoders.append(decoder)

@@ -450,6 +467,9 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module):
         intermediate_size,
         max_seq_len: int = 128,
         transpose_value: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0,
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -478,6 +498,9 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module):
             mode="prefill",
             transpose_value=self.transpose_value,
             dtype=np_dtype,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -554,6 +577,7 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
+    group_size = getattr(model.config, "group_size", 0)
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -561,34 +585,56 @@ def run_decode(
     k_biases = []
     v_biases = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp

-        if model.config.intermediate_size == 8960:
-            # for qwen2-1.5b
-            weights = [
-                (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-                (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-                (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-                (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-                (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-                (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-                (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-            ]
-        elif model.config.intermediate_size == 18944:
-            # for qwen2-7b
-            weights = [
-                (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-                (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-                (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-                (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-                (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-                (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-                (mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale),
-                (mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale)
-            ]
+        weights = []
+        if n_splits_linear == 1:
+            for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+
+            for l in attn_layer.o_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_linear == 1:
+            for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
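When n_splits > 1, the per-split weights and scales are stacked into single tensors before being handed to the fused decoder. A small sketch of that layout, assuming each sub-linear holds an int8 weight of shape (out, in // n_splits) and a per-output-channel scale (both shapes are assumptions for illustration):

    import torch

    n_splits, out_ch, group = 2, 8, 4
    split_weights = [torch.randint(-128, 127, (out_ch, group), dtype=torch.int8)
                     for _ in range(n_splits)]
    split_scales = [torch.rand(out_ch).to(torch.float16) for _ in range(n_splits)]

    stacked_w = torch.stack(split_weights, dim=0)     # (n_splits, out, in // n_splits)
    stacked_s = torch.stack(split_scales, dim=0)      # (n_splits, out)
    print(stacked_w.shape, stacked_s.shape)

    # The full dequantized weight can still be recovered by scaling each split
    # and concatenating along the input dimension:
    full = torch.cat([stacked_w[i].float() * stacked_s[i].float().unsqueeze(1)
                      for i in range(n_splits)], dim=1)
    print(full.shape)                                 # (out, in)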
@@ -598,9 +644,9 @@ def run_decode(
         layer_weights.extend(weights)
         input_layer_norm_weights.append(layer_norm_0)
         post_attn_layernorm_weights.append(layer_norm_1)
-        q_biases.append(attn_layer.q_proj.bias.to(torch.float16))
-        k_biases.append(attn_layer.k_proj.bias.to(torch.float16))
-        v_biases.append(attn_layer.v_proj.bias.to(torch.float16))
+        q_biases.append(attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16))
+        k_biases.append(attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16))
+        v_biases.append(attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16))

     multi_decoder = FusedQwenLowBitMultiDecoderlayer(
         parameters=layer_weights,
@@ -621,6 +667,9 @@ def run_decode(
         max_seq_len=max_seq_len,
         transpose_value=transpose_value_cache,
         do_print=False,
+        n_splits_linear=n_splits_linear,
+        n_splits_down_proj=n_splits_down_proj,
+        group_size=group_size
     )

     dist.barrier()
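The q/k/v biases are now read from the first sub-linear of each *_dq_list. That is consistent with splitting along the input dimension: the bias must be added exactly once after the partial products are summed, so it is carried separately rather than per split. A quick check of that identity (illustrative only):

    import torch

    full = torch.nn.Linear(16, 8, bias=True)
    x = torch.randn(3, 16)

    halves = [full.weight[:, :8], full.weight[:, 8:]]
    partial = sum(x[:, i * 8:(i + 1) * 8] @ w.T for i, w in enumerate(halves))
    y = partial + full.bias                      # bias added once, outside the splits
    assert torch.allclose(y, full(x), atol=1e-5)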
@@ -703,11 +752,15 @@ class DecodeRunner:

         self.forward_signal = torch.tensor(0, dtype=torch.int)

+        n_layers_per_rank = num_layers // (world_size - 1)
+        if num_layers % (world_size - 1) > 0:
+            n_layers_per_rank += 1
+
         for rank in range(1, world_size):
             input_q = mp.Queue()
             output_q = mp.Queue()
-            start_layer = (rank - 1) * (num_layers // (world_size - 1))
-            end_layer = (rank) * (num_layers // (world_size - 1))
+            start_layer = (rank - 1) * n_layers_per_rank
+            end_layer = (rank) * n_layers_per_rank
             if rank == world_size - 1:
                 end_layer = num_layers
             p = mp.Process(
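DecodeRunner now ceil-divides the decoder layers across the worker ranks (ranks 1..world_size-1) instead of floor-dividing, so no layer is dropped when the layer count is not a multiple of the worker count; the last rank absorbs the remainder. A standalone sketch of the resulting partition:

    def partition_layers(num_layers: int, world_size: int):
        n_layers_per_rank = num_layers // (world_size - 1)
        if num_layers % (world_size - 1) > 0:
            n_layers_per_rank += 1
        spans = []
        for rank in range(1, world_size):
            start_layer = (rank - 1) * n_layers_per_rank
            end_layer = rank * n_layers_per_rank
            if rank == world_size - 1:
                end_layer = num_layers
            spans.append((rank, start_layer, end_layer))
        return spans

    print(partition_layers(28, 3))   # [(1, 0, 14), (2, 14, 28)]
    print(partition_layers(28, 4))   # [(1, 0, 10), (2, 10, 20), (3, 20, 28)]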
@@ -787,39 +840,34 @@ def run_prefill(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
+    group_size = getattr(model.config, "group_size", 0)
     deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp

-        if model.config.intermediate_size == 8960:
-            # for qwen2-1.5b
-            weights = [
-                (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-                (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-                (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-                (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-                (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-                (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-                (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-            ]
-        elif model.config.intermediate_size == 18944:
-            # for qwen2-7b
-            weights = [
-                (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-                (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-                (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-                (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-                (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-                (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-                (mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale),
-                (mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale)
-            ]
+        weights = []
+
+        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                           attn_layer.v_proj_dq_list):
+            weights.append((q.weight, q.scale))
+            weights.append((k.weight, k.scale))
+            weights.append((v.weight, v.scale))
+
+        for l in attn_layer.o_proj_dq_list:
+            weights.append((l.weight, l.scale))
+        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
+            weights.append((g.weight, g.scale))
+            weights.append((u.weight, u.scale))
+        for l in mlp_layer.down_proj_dq_list:
+            weights.append((l.weight, l.scale))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -835,14 +883,17 @@ def run_prefill(
             cached_sin=cached_sin,
             layer_norm_0=layer_norm_0,
             layer_norm_1=layer_norm_1,
-            q_bias=attn_layer.q_proj.bias.to(torch.float16),
-            k_bias=attn_layer.k_proj.bias.to(torch.float16),
-            v_bias=attn_layer.v_proj.bias.to(torch.float16),
+            q_bias=attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16),
+            k_bias=attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16),
+            v_bias=attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16),
             layer_idx=layer_idx,
             rms_norm_eps=rms_norm_eps,
             intermediate_size=intermediate_size,
             max_seq_len=max_output_len,
             transpose_value=transpose_value_cache,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )

         layer_weights.extend(weights)