[NPU] Support save npu quantized model without npu dependency (#12647)
* support save awq * load quantized model & save npu compiled model * fix style * update * fix dll load issue * update error message * fix style
This commit is contained in:
		
							parent
							
								
									502461d836
								
							
						
					
					
						commit
						fae73eee79
					
				
					 5 changed files with 203 additions and 144 deletions
				
			
		| 
						 | 
				
			
			@ -27,7 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
 | 
			
		|||
 | 
			
		||||
from ipex_llm.utils.common.log4Error import invalidInputError
 | 
			
		||||
from ipex_llm.transformers.utils import logger, load_imatrix_data
 | 
			
		||||
from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post
 | 
			
		||||
from ipex_llm.transformers.npu_models.convert import optimize_llm
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def patch_flash_attn_import(filename: str) -> List[str]:
 | 
			
		||||
| 
						 | 
				
			
			@ -207,8 +207,6 @@ class _BaseAutoModelClass:
 | 
			
		|||
            model = model.eval()
 | 
			
		||||
            logger.info(f"Finish to convert model")
 | 
			
		||||
        else:
 | 
			
		||||
            from intel_npu_acceleration_library.compiler import create_npu_kernels
 | 
			
		||||
 | 
			
		||||
            if optimize_model:
 | 
			
		||||
                invalidInputError(
 | 
			
		||||
                    max_prompt_len < max_context_len,
 | 
			
		||||
| 
						 | 
				
			
			@ -232,11 +230,14 @@ class _BaseAutoModelClass:
 | 
			
		|||
                    "convert_model": convert_model,
 | 
			
		||||
                    "save_directory": save_directory,
 | 
			
		||||
                    "fuse_layers": fuse_layers,
 | 
			
		||||
                    "imatrix_data": imatrix_data
 | 
			
		||||
                    "imatrix_data": imatrix_data,
 | 
			
		||||
                    "skip_npu_logic": mock_device == "dummy",
 | 
			
		||||
                }
 | 
			
		||||
                # Dummy will skip npu related logic and save the quantized model
 | 
			
		||||
                if mock_device == "dummy":
 | 
			
		||||
                    model.save_low_bit = types.MethodType(save_low_bit, model)
 | 
			
		||||
                model = cls.optimize_npu_model(*args, **optimize_kwargs)
 | 
			
		||||
            else:
 | 
			
		||||
                from ipex_llm.transformers.npu_models.convert import optimize_llm
 | 
			
		||||
                optimize_llm(model)
 | 
			
		||||
                with torch.no_grad():
 | 
			
		||||
                    cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
 | 
			
		||||
| 
						 | 
				
			
			@ -258,7 +259,6 @@ class _BaseAutoModelClass:
 | 
			
		|||
    def optimize_npu_model(cls, *args, **kwargs):
 | 
			
		||||
 | 
			
		||||
        from ipex_llm.transformers.npu_models.convert_mp import optimize_llm_pre, optimize_llm
 | 
			
		||||
        from intel_npu_acceleration_library.compiler import create_npu_kernels
 | 
			
		||||
 | 
			
		||||
        model = kwargs.pop("model")
 | 
			
		||||
        qtype = kwargs.pop("qtype", "sym_int4_rtn")
 | 
			
		||||
| 
						 | 
				
			
			@ -275,6 +275,7 @@ class _BaseAutoModelClass:
 | 
			
		|||
        save_directory = kwargs.pop('save_directory', None)
 | 
			
		||||
        fuse_layers = kwargs.pop('fuse_layers', None)
 | 
			
		||||
        imatrix_data = kwargs.pop('imatrix_data', None)
 | 
			
		||||
        skip_npu_logic = kwargs.pop("skip_npu_logic", False)
 | 
			
		||||
        invalidInputError(save_directory is not None,
 | 
			
		||||
                          "Please provide the path to save converted model "
 | 
			
		||||
                          "through `save_directory`.")
 | 
			
		||||
| 
						 | 
				
			
			@ -294,51 +295,58 @@ class _BaseAutoModelClass:
 | 
			
		|||
            cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
 | 
			
		||||
                             quantization_group_size, imatrix_data,
 | 
			
		||||
                             *args, **kwargs)
 | 
			
		||||
            create_npu_kernels(llm)
 | 
			
		||||
            if not skip_npu_logic:
 | 
			
		||||
                from intel_npu_acceleration_library.compiler import create_npu_kernels
 | 
			
		||||
                create_npu_kernels(llm)
 | 
			
		||||
        model = model.eval()
 | 
			
		||||
        logger.info(f"Finish to convert model")
 | 
			
		||||
        model.config.update({"bigdl_transformers_low_bit": qtype})
 | 
			
		||||
        model.share_memory()
 | 
			
		||||
 | 
			
		||||
        if not pipeline:
 | 
			
		||||
            if model.config.model_type in ["qwen2", "llama", "minicpm"]:
 | 
			
		||||
                from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
 | 
			
		||||
                optimize_llm_single_process(
 | 
			
		||||
                    llm,
 | 
			
		||||
                    kv_len=max_context_len,
 | 
			
		||||
                    max_prompt_len=max_prompt_len,
 | 
			
		||||
                    transpose_value_cache=transpose_value_cache,
 | 
			
		||||
                    group_size=quantization_group_size,
 | 
			
		||||
                    qtype=qtype,
 | 
			
		||||
                    save_directory=save_directory,
 | 
			
		||||
                    fuse_layers=fuse_layers,
 | 
			
		||||
                    has_llm=hasattr(model, "llm")
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                optimize_llm(
 | 
			
		||||
                    llm,
 | 
			
		||||
                    max_context_len=max_context_len,
 | 
			
		||||
                    max_prompt_len=max_prompt_len,
 | 
			
		||||
                    inter_pp=inter_pp,
 | 
			
		||||
                    intra_pp=intra_pp,
 | 
			
		||||
                    transpose_value_cache=transpose_value_cache,
 | 
			
		||||
                    group_size=quantization_group_size
 | 
			
		||||
                )
 | 
			
		||||
        if skip_npu_logic:
 | 
			
		||||
            model.save_low_bit(model_dir=save_directory)
 | 
			
		||||
        else:
 | 
			
		||||
            from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
 | 
			
		||||
                import convert_llm
 | 
			
		||||
            convert_llm(llm,
 | 
			
		||||
            model.share_memory()
 | 
			
		||||
 | 
			
		||||
            if not pipeline:
 | 
			
		||||
                if model.config.model_type in ["qwen2", "llama", "minicpm"]:
 | 
			
		||||
                    from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
 | 
			
		||||
                    optimize_llm_single_process(
 | 
			
		||||
                        llm,
 | 
			
		||||
                        kv_len=max_context_len,
 | 
			
		||||
                        max_prompt_len=max_prompt_len,
 | 
			
		||||
                        transpose_value_cache=transpose_value_cache,
 | 
			
		||||
                        group_size=quantization_group_size,
 | 
			
		||||
                        qtype=qtype,
 | 
			
		||||
                        convert_model=convert_model,
 | 
			
		||||
                        save_directory=save_directory,
 | 
			
		||||
                        fuse_layers=fuse_layers)
 | 
			
		||||
        model.save_low_bit = types.MethodType(save_low_bit, model)
 | 
			
		||||
        model.save_low_bit(save_directory)
 | 
			
		||||
        logger.info(f"Converted model has already saved to {save_directory}.")
 | 
			
		||||
                        fuse_layers=fuse_layers,
 | 
			
		||||
                        has_llm=hasattr(model, "llm")
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    optimize_llm(
 | 
			
		||||
                        llm,
 | 
			
		||||
                        max_context_len=max_context_len,
 | 
			
		||||
                        max_prompt_len=max_prompt_len,
 | 
			
		||||
                        inter_pp=inter_pp,
 | 
			
		||||
                        intra_pp=intra_pp,
 | 
			
		||||
                        transpose_value_cache=transpose_value_cache,
 | 
			
		||||
                        group_size=quantization_group_size
 | 
			
		||||
                    )
 | 
			
		||||
            else:
 | 
			
		||||
                from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
 | 
			
		||||
                    import convert_llm
 | 
			
		||||
                convert_llm(llm,
 | 
			
		||||
                            kv_len=max_context_len,
 | 
			
		||||
                            max_prompt_len=max_prompt_len,
 | 
			
		||||
                            transpose_value_cache=transpose_value_cache,
 | 
			
		||||
                            group_size=quantization_group_size,
 | 
			
		||||
                            qtype=qtype,
 | 
			
		||||
                            convert_model=convert_model,
 | 
			
		||||
                            save_directory=save_directory,
 | 
			
		||||
                            fuse_layers=fuse_layers)
 | 
			
		||||
            model.save_low_bit = types.MethodType(save_low_bit, model)
 | 
			
		||||
            model.save_low_bit(save_directory)
 | 
			
		||||
            logger.info(f"Converted model has already saved to {save_directory}.")
 | 
			
		||||
 | 
			
		||||
        return model
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
| 
						 | 
				
			
			@ -379,6 +387,7 @@ class _BaseAutoModelClass:
 | 
			
		|||
        intra_pp = kwargs.pop("intra_pp", None)
 | 
			
		||||
        transpose_value_cache = kwargs.pop("transpose_value_cache", True)
 | 
			
		||||
        modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
 | 
			
		||||
        save_directory = kwargs.pop('save_directory', None)
 | 
			
		||||
 | 
			
		||||
        from transformers.models.auto.configuration_auto import AutoConfig
 | 
			
		||||
        from transformers.modeling_utils import no_init_weights, get_state_dict_dtype
 | 
			
		||||
| 
						 | 
				
			
			@ -650,16 +659,37 @@ class _BaseAutoModelClass:
 | 
			
		|||
            param.requires_grad_(False)
 | 
			
		||||
 | 
			
		||||
        if optimize_model and not pipeline:
 | 
			
		||||
            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
 | 
			
		||||
            optimize_llm(
 | 
			
		||||
                llm,
 | 
			
		||||
                max_context_len=max_context_len,
 | 
			
		||||
                max_prompt_len=max_prompt_len,
 | 
			
		||||
                inter_pp=inter_pp,
 | 
			
		||||
                intra_pp=intra_pp,
 | 
			
		||||
                transpose_value_cache=transpose_value_cache,
 | 
			
		||||
                group_size=quantization_group_size
 | 
			
		||||
            )
 | 
			
		||||
            if model.config.model_type in ["qwen2", "llama", "minicpm"]:
 | 
			
		||||
                from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
 | 
			
		||||
                if save_directory is None:
 | 
			
		||||
                    invalidInputError(False,
 | 
			
		||||
                                      "Please specify the save_directory, the path of folder " +
 | 
			
		||||
                                      "to save the compiled NPU model. If path not exists, " +
 | 
			
		||||
                                      "the compiled NPU model will be saved there. " +
 | 
			
		||||
                                      "Else, program will exit.")
 | 
			
		||||
 | 
			
		||||
                optimize_llm_single_process(
 | 
			
		||||
                    llm,
 | 
			
		||||
                    kv_len=max_context_len,
 | 
			
		||||
                    max_prompt_len=max_prompt_len,
 | 
			
		||||
                    transpose_value_cache=transpose_value_cache,
 | 
			
		||||
                    group_size=quantization_group_size,
 | 
			
		||||
                    qtype=qtype,
 | 
			
		||||
                    save_directory=save_directory,
 | 
			
		||||
                    fuse_layers=None,
 | 
			
		||||
                    has_llm=hasattr(model, "llm")
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
 | 
			
		||||
                optimize_llm(
 | 
			
		||||
                    llm,
 | 
			
		||||
                    max_context_len=max_context_len,
 | 
			
		||||
                    max_prompt_len=max_prompt_len,
 | 
			
		||||
                    inter_pp=inter_pp,
 | 
			
		||||
                    intra_pp=intra_pp,
 | 
			
		||||
                    transpose_value_cache=transpose_value_cache,
 | 
			
		||||
                    group_size=quantization_group_size
 | 
			
		||||
                )
 | 
			
		||||
        elif optimize_model and pipeline:
 | 
			
		||||
            from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
 | 
			
		||||
                import convert_llm
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,7 +18,7 @@ import torch
 | 
			
		|||
import importlib
 | 
			
		||||
import numpy as np
 | 
			
		||||
from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params
 | 
			
		||||
from ipex_llm.transformers.npu_models.lm_head import LMHeadLinear, SlicedLMHead
 | 
			
		||||
from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead
 | 
			
		||||
from ipex_llm.utils.common.log4Error import invalidInputError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,16 +21,25 @@
 | 
			
		|||
# SPDX-License-Identifier: Apache 2.0
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
 | 
			
		||||
from intel_npu_acceleration_library.dtypes import NPUDtype
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import torch
 | 
			
		||||
from torch.nn import Parameter
 | 
			
		||||
import uuid
 | 
			
		||||
import math
 | 
			
		||||
from intel_npu_acceleration_library.backend import run_matmul
 | 
			
		||||
from typing import Optional, Union
 | 
			
		||||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
import importlib
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_acclib_available():
 | 
			
		||||
    return importlib.util.find_spec("intel_npu_acceleration_library") is not None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if is_acclib_available():
 | 
			
		||||
    from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
 | 
			
		||||
    from intel_npu_acceleration_library.dtypes import NPUDtype
 | 
			
		||||
    from intel_npu_acceleration_library.backend import run_matmul
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Linear(torch.nn.Module):
 | 
			
		||||
| 
						 | 
				
			
			@ -63,6 +72,7 @@ class Linear(torch.nn.Module):
 | 
			
		|||
        if self.training:
 | 
			
		||||
            out = self._mm(x, self.weight, None)
 | 
			
		||||
        else:
 | 
			
		||||
            from intel_npu_acceleration_library.backend import run_matmul
 | 
			
		||||
            out = run_matmul(x, self.weight, None, self.op_id)
 | 
			
		||||
 | 
			
		||||
        if self.bias is None:
 | 
			
		||||
| 
						 | 
				
			
			@ -105,6 +115,8 @@ class Linear(torch.nn.Module):
 | 
			
		|||
        Returns:
 | 
			
		||||
            Union[Linear, QuantizedLinear]: A NPU linear layer
 | 
			
		||||
        """
 | 
			
		||||
        from intel_npu_acceleration_library.quantization import quantize_tensor, compress_to_i4
 | 
			
		||||
        from intel_npu_acceleration_library.dtypes import NPUDtype
 | 
			
		||||
        if dtype.is_floating_point:
 | 
			
		||||
            if bias is None:
 | 
			
		||||
                return Linear(weight.to(dtype), None)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,96 +16,6 @@
 | 
			
		|||
import torch
 | 
			
		||||
from torch import nn
 | 
			
		||||
import numpy as np
 | 
			
		||||
from filelock import FileLock
 | 
			
		||||
from intel_npu_acceleration_library.backend import NNFactory
 | 
			
		||||
from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LMHeadLinear(NNFactory):
 | 
			
		||||
    """Quantized Linear class for sliced lm_head, computing a matrix matrix multiplication
 | 
			
		||||
    with weights prefetching."""
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        inC: int,
 | 
			
		||||
        outC: int,
 | 
			
		||||
        batch: int,
 | 
			
		||||
        split_num: int = 2,
 | 
			
		||||
        profile: bool = False,
 | 
			
		||||
        device: str = "NPU",
 | 
			
		||||
        dtype: np.dtype = np.int8,
 | 
			
		||||
        use_split: bool = False,
 | 
			
		||||
        group_size: int = 0,
 | 
			
		||||
        asym: bool = False,
 | 
			
		||||
    ):
 | 
			
		||||
        """Initialize the LMHeadLinear class.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            inC (int): input channels
 | 
			
		||||
            outC (int): output channels
 | 
			
		||||
            batch (int): batch
 | 
			
		||||
            split_num (int): split in_features of lm_head to how many parts
 | 
			
		||||
            profile (bool): Enable/Disable profiling. Defaults to False.
 | 
			
		||||
            device (str): Target device, default to "NPU".
 | 
			
		||||
            dtype (np.dtype): weights datatype. Defaults to np.int8.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        super().__init__(profile, device)
 | 
			
		||||
        self.inC, self.outC = inC, outC
 | 
			
		||||
        self.batch = batch
 | 
			
		||||
 | 
			
		||||
        self.split_num = split_num
 | 
			
		||||
        if use_split:
 | 
			
		||||
            input = self.parameter((1, self.batch, self.inC))
 | 
			
		||||
            res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
 | 
			
		||||
                                       scale_factor=(group_size == 0), asym=asym)
 | 
			
		||||
        else:
 | 
			
		||||
            input = self.parameter((self.batch, self.inC))
 | 
			
		||||
            split_size = self.inC // split_num // 2 * 2
 | 
			
		||||
 | 
			
		||||
            for i in range(self.split_num):
 | 
			
		||||
                start_idx = i * split_size
 | 
			
		||||
                end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
 | 
			
		||||
                input_slice = self.slice(input, begin=[0, start_idx],
 | 
			
		||||
                                         end=[self.batch, end_idx])
 | 
			
		||||
                linear_slice = self.linear(input_slice, outC, split_size, bias=False,
 | 
			
		||||
                                           wt_dtype=dtype, asym=asym)
 | 
			
		||||
                if i == 0:
 | 
			
		||||
                    res = linear_slice
 | 
			
		||||
                else:
 | 
			
		||||
                    res += linear_slice
 | 
			
		||||
 | 
			
		||||
        print("start compiling lm_head")
 | 
			
		||||
        self.compile()
 | 
			
		||||
        print("end compiling lm_head")
 | 
			
		||||
 | 
			
		||||
    def set_weights(self, op_id, weights):
 | 
			
		||||
        self.set_weights_async(op_id, weights)
 | 
			
		||||
        with FileLock(f"lmhead_run.lock"):
 | 
			
		||||
            backend_lib.run(self._mm)
 | 
			
		||||
 | 
			
		||||
    def set_weights_async(self, op_id, weights):
 | 
			
		||||
        self.setWeights(1, op_id, *weights)
 | 
			
		||||
 | 
			
		||||
    def run(
 | 
			
		||||
        self, X: np.ndarray
 | 
			
		||||
    ) -> np.ndarray:
 | 
			
		||||
        """Run the layer:  $X * (W * S)^T$ .
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            X (np.ndarray): activation
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            RuntimeError: Input, weights or scale shape mismatch
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            np.ndarray: result
 | 
			
		||||
        """
 | 
			
		||||
        self.set_input_tensor(X, 0)
 | 
			
		||||
        self.elapsed = backend_lib.run(self._mm)
 | 
			
		||||
        if len(self.out) == 1:
 | 
			
		||||
            return self.out[0]
 | 
			
		||||
        return self.out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SlicedLMHead(nn.Module):
 | 
			
		||||
| 
						 | 
				
			
			@ -160,6 +70,7 @@ class SlicedLMHead(nn.Module):
 | 
			
		|||
        return self.lm_heads[0].weight.dtype
 | 
			
		||||
 | 
			
		||||
    def get_fused_lm_head(self):
 | 
			
		||||
        from ipex_llm.transformers.npu_models.lm_head_linear import LMHeadLinear
 | 
			
		||||
        np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
 | 
			
		||||
        self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
 | 
			
		||||
                                          False, "NPU", dtype=np_dtype, use_split=self.use_split,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,106 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
from filelock import FileLock
 | 
			
		||||
from intel_npu_acceleration_library.backend import NNFactory
 | 
			
		||||
from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LMHeadLinear(NNFactory):
 | 
			
		||||
    """Quantized Linear class for sliced lm_head, computing a matrix matrix multiplication
 | 
			
		||||
    with weights prefetching."""
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        inC: int,
 | 
			
		||||
        outC: int,
 | 
			
		||||
        batch: int,
 | 
			
		||||
        split_num: int = 2,
 | 
			
		||||
        profile: bool = False,
 | 
			
		||||
        device: str = "NPU",
 | 
			
		||||
        dtype: np.dtype = np.int8,
 | 
			
		||||
        use_split: bool = False,
 | 
			
		||||
        group_size: int = 0,
 | 
			
		||||
        asym: bool = False,
 | 
			
		||||
    ):
 | 
			
		||||
        """Initialize the LMHeadLinear class.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            inC (int): input channels
 | 
			
		||||
            outC (int): output channels
 | 
			
		||||
            batch (int): batch
 | 
			
		||||
            split_num (int): split in_features of lm_head to how many parts
 | 
			
		||||
            profile (bool): Enable/Disable profiling. Defaults to False.
 | 
			
		||||
            device (str): Target device, default to "NPU".
 | 
			
		||||
            dtype (np.dtype): weights datatype. Defaults to np.int8.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        super().__init__(profile, device)
 | 
			
		||||
        self.inC, self.outC = inC, outC
 | 
			
		||||
        self.batch = batch
 | 
			
		||||
 | 
			
		||||
        self.split_num = split_num
 | 
			
		||||
        if use_split:
 | 
			
		||||
            input = self.parameter((1, self.batch, self.inC))
 | 
			
		||||
            res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
 | 
			
		||||
                                       scale_factor=(group_size == 0), asym=asym)
 | 
			
		||||
        else:
 | 
			
		||||
            input = self.parameter((self.batch, self.inC))
 | 
			
		||||
            split_size = self.inC // split_num // 2 * 2
 | 
			
		||||
 | 
			
		||||
            for i in range(self.split_num):
 | 
			
		||||
                start_idx = i * split_size
 | 
			
		||||
                end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
 | 
			
		||||
                input_slice = self.slice(input, begin=[0, start_idx],
 | 
			
		||||
                                         end=[self.batch, end_idx])
 | 
			
		||||
                linear_slice = self.linear(input_slice, outC, split_size, bias=False,
 | 
			
		||||
                                           wt_dtype=dtype, asym=asym)
 | 
			
		||||
                if i == 0:
 | 
			
		||||
                    res = linear_slice
 | 
			
		||||
                else:
 | 
			
		||||
                    res += linear_slice
 | 
			
		||||
 | 
			
		||||
        print("start compiling lm_head")
 | 
			
		||||
        self.compile()
 | 
			
		||||
        print("end compiling lm_head")
 | 
			
		||||
 | 
			
		||||
    def set_weights(self, op_id, weights):
 | 
			
		||||
        self.set_weights_async(op_id, weights)
 | 
			
		||||
        with FileLock(f"lmhead_run.lock"):
 | 
			
		||||
            backend_lib.run(self._mm)
 | 
			
		||||
 | 
			
		||||
    def set_weights_async(self, op_id, weights):
 | 
			
		||||
        self.setWeights(1, op_id, *weights)
 | 
			
		||||
 | 
			
		||||
    def run(
 | 
			
		||||
        self, X: np.ndarray
 | 
			
		||||
    ) -> np.ndarray:
 | 
			
		||||
        """Run the layer:  $X * (W * S)^T$ .
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            X (np.ndarray): activation
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            RuntimeError: Input, weights or scale shape mismatch
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            np.ndarray: result
 | 
			
		||||
        """
 | 
			
		||||
        self.set_input_tensor(X, 0)
 | 
			
		||||
        self.elapsed = backend_lib.run(self._mm)
 | 
			
		||||
        if len(self.out) == 1:
 | 
			
		||||
            return self.out[0]
 | 
			
		||||
        return self.out
 | 
			
		||||
		Loading…
	
		Reference in a new issue