LLM: 2bit quantization initial support (#10042)
* basis quantize support * fix new module name * small update * and mixed int4 with iq2_xxs * remove print * code refactor * fix style * meet code review
This commit is contained in:
		
							parent
							
								
									f440cb4fba
								
							
						
					
					
						commit
						d61f4905ac
					
				
					 6 changed files with 164 additions and 22 deletions
				
			
		| 
						 | 
				
			
			@ -965,6 +965,30 @@ _lib.ggml_quantize_tensor.argtypes = [
 | 
			
		|||
_lib.ggml_quantize_tensor.restype = ctypes.c_size_t
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ggml_quantize_tensor_with_weights(
 | 
			
		||||
    src,  # type: ctypes.Array[ctypes.c_float] # type: ignore
 | 
			
		||||
    dst: ctypes.c_void_p,
 | 
			
		||||
    qtype: ctypes.c_int,
 | 
			
		||||
    nrow: ctypes.c_int,
 | 
			
		||||
    n_per_row: ctypes.c_int,
 | 
			
		||||
    hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
 | 
			
		||||
    weights,  # type: ctypes.Array[ctypes.c_float] # type: ignore
 | 
			
		||||
) -> int:
 | 
			
		||||
    return _lib.ggml_quantize_tensor_with_weights(src, dst, qtype, nrow, n_per_row, hist, weights)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_lib.ggml_quantize_tensor_with_weights.argtypes = [
 | 
			
		||||
    ctypes.POINTER(ctypes.c_float),
 | 
			
		||||
    ctypes.c_void_p,
 | 
			
		||||
    ctypes.c_int,
 | 
			
		||||
    ctypes.c_int,
 | 
			
		||||
    ctypes.c_int,
 | 
			
		||||
    ctypes.POINTER(ctypes.c_int64),
 | 
			
		||||
    ctypes.POINTER(ctypes.c_float),
 | 
			
		||||
]
 | 
			
		||||
_lib.ggml_quantize_tensor_with_weights.restype = ctypes.c_size_t
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ggml_type_size(qtype: ctypes.c_int) -> int:
 | 
			
		||||
    return _lib.ggml_type_size(qtype)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -39,7 +39,9 @@ ggml_tensor_qtype = {"sym_int4": 2,   # q4_0 in ggml
 | 
			
		|||
                     "mixed_fp8": 18,     # Mixture of Formats Quantization 8 bits
 | 
			
		||||
                     "fp8_e5m2": 19,      # fp8 in e5m2 format
 | 
			
		||||
                     "fp8": 19,           # fp8 in e5m2 format
 | 
			
		||||
                     "bf16": 20}
 | 
			
		||||
                     "bf16": 20,
 | 
			
		||||
                     "iq2_xxs": 21,
 | 
			
		||||
                     "iq2_xs": 22}
 | 
			
		||||
 | 
			
		||||
_llama_quantize_type = {"q4_0": 2,
 | 
			
		||||
                        "q4_1": 3,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -43,7 +43,7 @@ import warnings
 | 
			
		|||
import transformers
 | 
			
		||||
import importlib.util
 | 
			
		||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		||||
from .utils import logger
 | 
			
		||||
from .utils import logger, get_cur_qtype_and_imatrix
 | 
			
		||||
from typing import Union
 | 
			
		||||
import numpy as np
 | 
			
		||||
import os
 | 
			
		||||
| 
						 | 
				
			
			@ -190,7 +190,8 @@ def convert_gptq(module, awq=False, llm_awq=False):
 | 
			
		|||
 | 
			
		||||
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		||||
                                 current_key_name=None, convert_shape_only=False,
 | 
			
		||||
                                 cpu_embedding=False, prefix_name=''):
 | 
			
		||||
                                 cpu_embedding=False, prefix_name='',
 | 
			
		||||
                                 imatrix_data=None):
 | 
			
		||||
    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
 | 
			
		||||
        FP16Linear, BF16Linear
 | 
			
		||||
    from bigdl.llm.transformers.embedding import LLMEmbedding
 | 
			
		||||
| 
						 | 
				
			
			@ -248,7 +249,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                            module.bias is not None,
 | 
			
		||||
                            mp_group=mp_group,
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                        cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
 | 
			
		||||
                                                                           full_module_name,
 | 
			
		||||
                                                                           imatrix_data)
 | 
			
		||||
                        device = module.weight.data.device
 | 
			
		||||
                        # Copy the weights
 | 
			
		||||
                        paramsLowBit = FP4Params(data=module.weight.data,
 | 
			
		||||
| 
						 | 
				
			
			@ -256,7 +259,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                                                 quantized=False,
 | 
			
		||||
                                                 _shape=None,
 | 
			
		||||
                                                 convert_shape_only=convert_shape_only,
 | 
			
		||||
                                                 qtype=qtype).to(device)
 | 
			
		||||
                                                 qtype=cur_qtype,
 | 
			
		||||
                                                 imatrix=cur_imatrix,
 | 
			
		||||
                                                 in_features=in_features).to(device)
 | 
			
		||||
                        new_linear._parameters['weight'] = paramsLowBit
 | 
			
		||||
                        if module.bias is not None:
 | 
			
		||||
                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
 | 
			
		||||
| 
						 | 
				
			
			@ -328,7 +333,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                current_key_name,
 | 
			
		||||
                convert_shape_only,
 | 
			
		||||
                cpu_embedding,
 | 
			
		||||
                prefix_name=prefix_name + '.' + name if prefix_name != '' else name
 | 
			
		||||
                prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
 | 
			
		||||
                imatrix_data=imatrix_data
 | 
			
		||||
            )
 | 
			
		||||
            has_been_replaced = _flag or has_been_replaced
 | 
			
		||||
    return model, has_been_replaced
 | 
			
		||||
| 
						 | 
				
			
			@ -505,7 +511,8 @@ def _optimize_pre(model):
 | 
			
		|||
def ggml_convert_low_bit(model, qtype, optimize_model=True,
 | 
			
		||||
                         convert_shape_only=False, device="cpu",
 | 
			
		||||
                         modules_to_not_convert=None, cpu_embedding=False,
 | 
			
		||||
                         lightweight_bmm=False, torch_dtype="auto"):
 | 
			
		||||
                         lightweight_bmm=False, torch_dtype="auto",
 | 
			
		||||
                         imatrix_data=None):
 | 
			
		||||
    logger.info(f"Converting the current model to "
 | 
			
		||||
                f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
 | 
			
		||||
                f"format......")
 | 
			
		||||
| 
						 | 
				
			
			@ -517,6 +524,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 | 
			
		|||
    model, has_been_replaced = _replace_with_low_bit_linear(
 | 
			
		||||
        model, qtype, modules_to_not_convert,
 | 
			
		||||
        None, convert_shape_only, cpu_embedding,
 | 
			
		||||
        imatrix_data=imatrix_data,
 | 
			
		||||
    )
 | 
			
		||||
    if not has_been_replaced:
 | 
			
		||||
        warnings.warn(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -70,6 +70,8 @@ FP4 = ggml_tensor_qtype["fp4"]
 | 
			
		|||
MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
 | 
			
		||||
MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
 | 
			
		||||
FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
 | 
			
		||||
IQ2_XXS = ggml_tensor_qtype["iq2_xxs"]
 | 
			
		||||
IQ2_XS = ggml_tensor_qtype["iq2_xs"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_block_size(qtype: str):
 | 
			
		||||
| 
						 | 
				
			
			@ -81,7 +83,9 @@ def get_qk_size(qtype: int):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
 | 
			
		||||
                       device=None, convert_shape_only=False):
 | 
			
		||||
                       device=None, convert_shape_only=False,
 | 
			
		||||
                       imatrix: torch.Tensor=None,
 | 
			
		||||
                       in_features: int=None):
 | 
			
		||||
    QK = ggml.ggml_qk_size(qtype)
 | 
			
		||||
    block_size_in_bytes = ggml.ggml_type_size(qtype)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -89,12 +93,10 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
 | 
			
		|||
                      "Input tensor must be float32")
 | 
			
		||||
    src = tensor.data.data_ptr()
 | 
			
		||||
    src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
 | 
			
		||||
    n = tensor.numel()
 | 
			
		||||
    invalidInputError(n % QK == 0,
 | 
			
		||||
                      "Input tensor size must be multiple of 64")
 | 
			
		||||
    n = tensor.numel()  # all elements
 | 
			
		||||
    k = tensor.shape[-1]
 | 
			
		||||
    invalidInputError(k % QK == 0,
 | 
			
		||||
                      "Last dim of input tensor must be multiple of 64")
 | 
			
		||||
                      f"Last dim of input tensor must be multiple of {QK}")
 | 
			
		||||
 | 
			
		||||
    dst_size = (n // QK) * block_size_in_bytes
 | 
			
		||||
    dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
 | 
			
		||||
| 
						 | 
				
			
			@ -103,7 +105,16 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
 | 
			
		|||
    if not convert_shape_only and device != 'meta':
 | 
			
		||||
        dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
 | 
			
		||||
        hist = (ctypes.c_int64 * 16)()
 | 
			
		||||
        ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
 | 
			
		||||
        if qtype not in [IQ2_XXS, IQ2_XS]:
 | 
			
		||||
            ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
 | 
			
		||||
        else:
 | 
			
		||||
            # quantize with importance matrix
 | 
			
		||||
            imatrix = imatrix.data.data_ptr()
 | 
			
		||||
            imatrix = ctypes.cast(imatrix, ctypes.POINTER(ctypes.c_float))
 | 
			
		||||
            # pass nrow and n_per_row
 | 
			
		||||
            ggml.ggml_quantize_tensor_with_weights(src, dst, qtype,
 | 
			
		||||
                                                   n // in_features, in_features,
 | 
			
		||||
                                                   hist, imatrix)
 | 
			
		||||
    return dst_tensor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -193,7 +204,9 @@ class FP4Params(torch.nn.Parameter):
 | 
			
		|||
                quantized=False,
 | 
			
		||||
                _shape=None,
 | 
			
		||||
                convert_shape_only=False,
 | 
			
		||||
                qtype=None):
 | 
			
		||||
                qtype=None,
 | 
			
		||||
                imatrix=None,
 | 
			
		||||
                in_features=None):
 | 
			
		||||
        if data is None:
 | 
			
		||||
            data = torch.empty(0)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -203,6 +216,8 @@ class FP4Params(torch.nn.Parameter):
 | 
			
		|||
        self._shape = _shape
 | 
			
		||||
        self.qtype = qtype
 | 
			
		||||
        self.convert_shape_only = convert_shape_only
 | 
			
		||||
        self.imatrix = imatrix
 | 
			
		||||
        self.in_features = in_features
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def ggml_mse(self, w, ggml_qtype, device):
 | 
			
		||||
| 
						 | 
				
			
			@ -255,7 +270,9 @@ class FP4Params(torch.nn.Parameter):
 | 
			
		|||
            else:
 | 
			
		||||
                w_quantized = ggml_convert_qtype(w, self.qtype,
 | 
			
		||||
                                                 device=device,
 | 
			
		||||
                                                 convert_shape_only=self.convert_shape_only)
 | 
			
		||||
                                                 convert_shape_only=self.convert_shape_only,
 | 
			
		||||
                                                 imatrix=self.imatrix,
 | 
			
		||||
                                                 in_features=self.in_features)
 | 
			
		||||
                self.data = w_quantized
 | 
			
		||||
            self.quantized = True
 | 
			
		||||
            self._shape = w.shape
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -41,7 +41,7 @@ import transformers
 | 
			
		|||
from transformers.configuration_utils import PretrainedConfig
 | 
			
		||||
from .utils import extract_local_archive_file, \
 | 
			
		||||
    load_state_dict, \
 | 
			
		||||
    get_local_shard_files
 | 
			
		||||
    get_local_shard_files, load_imatrix_data
 | 
			
		||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
from bigdl.llm.transformers.gguf.api import load_gguf_model
 | 
			
		||||
| 
						 | 
				
			
			@ -107,10 +107,10 @@ class _BaseAutoModelClass:
 | 
			
		|||
        :param load_in_low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``,
 | 
			
		||||
                                ``'sym_int5'``, ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``,
 | 
			
		||||
                                ``'nf4'``, ``'fp4'``, ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``,
 | 
			
		||||
                                ``'fp16'`` or ``'bf16'``, ``'sym_int4'`` means symmetric int 4,
 | 
			
		||||
                                ``'asym_int4'`` means asymmetric int 4, ``'nf4'`` means 4-bit
 | 
			
		||||
                                NormalFloat, etc. Relevant low bit optimizations will be applied
 | 
			
		||||
                                to the model.
 | 
			
		||||
                                ``'iq2_xxs'``, ``'iq2_xs'``, ``'fp16'`` or ``'bf16'``,
 | 
			
		||||
                                ``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
 | 
			
		||||
                                asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
 | 
			
		||||
                                Relevant low bit optimizations will be applied to the model.
 | 
			
		||||
        :param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
 | 
			
		||||
                               Default to be ``True``.
 | 
			
		||||
        :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
 | 
			
		||||
| 
						 | 
				
			
			@ -121,6 +121,9 @@ class _BaseAutoModelClass:
 | 
			
		|||
            to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
 | 
			
		||||
        :param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
 | 
			
		||||
            to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
 | 
			
		||||
        :param imatrix: str value, represent filename of importance matrix pretrained on
 | 
			
		||||
            specific datasets for use with the improved quantization methods recently
 | 
			
		||||
            added to llama.cpp.
 | 
			
		||||
        :return: a model instance
 | 
			
		||||
        """
 | 
			
		||||
        pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
 | 
			
		||||
| 
						 | 
				
			
			@ -243,6 +246,12 @@ class _BaseAutoModelClass:
 | 
			
		|||
                else:
 | 
			
		||||
                    kwargs["pretraining_tp"] = 1
 | 
			
		||||
            q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
 | 
			
		||||
            if q_k in ["iq2_xxs", "iq2_xs"]:
 | 
			
		||||
                imatrix_file = kwargs.pop("imatrix", None)
 | 
			
		||||
                invalidInputError(imatrix_file is not None,
 | 
			
		||||
                                  "For iq2_xxs and iq2_xs quantization, imatrix is needed.")
 | 
			
		||||
                imatrix_data = load_imatrix_data(imatrix_file)
 | 
			
		||||
                kwargs['imatrix_data'] = imatrix_data
 | 
			
		||||
            model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
            if speculative:
 | 
			
		||||
| 
						 | 
				
			
			@ -285,7 +294,8 @@ class _BaseAutoModelClass:
 | 
			
		|||
        invalidInputError(q_k in ggml_tensor_qtype,
 | 
			
		||||
                          f"Unknown load_in_low_bit value: {q_k}, expected:"
 | 
			
		||||
                          f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
 | 
			
		||||
                          "fp4, fp8, fp8_e4m3, fp8_e5m2, fp16,  bf16, mixed_fp4 or mixed_fp8.")
 | 
			
		||||
                          f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16,  bf16, iq2_xxs, iq2_xs, "
 | 
			
		||||
                          f"mixed_fp4 or mixed_fp8.")
 | 
			
		||||
        qtype = ggml_tensor_qtype[q_k]
 | 
			
		||||
 | 
			
		||||
        # In case it needs a second try,
 | 
			
		||||
| 
						 | 
				
			
			@ -299,6 +309,7 @@ class _BaseAutoModelClass:
 | 
			
		|||
            cpu_embedding = True
 | 
			
		||||
        lightweight_bmm = kwargs.pop("lightweight_bmm", False)
 | 
			
		||||
        quant_config = kwargs.pop("quantization_config", None)
 | 
			
		||||
        imatrix_data = kwargs.pop("imatrix_data", None)
 | 
			
		||||
        _args = copy.deepcopy(args)
 | 
			
		||||
        _kwargs = copy.deepcopy(kwargs)
 | 
			
		||||
        awq_config = None
 | 
			
		||||
| 
						 | 
				
			
			@ -359,7 +370,8 @@ class _BaseAutoModelClass:
 | 
			
		|||
        model = ggml_convert_low_bit(model, qtype, optimize_model,
 | 
			
		||||
                                     modules_to_not_convert=modules_to_not_convert,
 | 
			
		||||
                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
 | 
			
		||||
                                     torch_dtype=kwargs.get("torch_dtype", 'auto'))
 | 
			
		||||
                                     torch_dtype=kwargs.get("torch_dtype", 'auto'),
 | 
			
		||||
                                     imatrix_data=imatrix_data)
 | 
			
		||||
        model.config.update({"bigdl_transformers_low_bit": q_k})
 | 
			
		||||
 | 
			
		||||
        # enable tie_word_embeddings for MPT
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -41,11 +41,14 @@
 | 
			
		|||
# SOFTWARE.
 | 
			
		||||
import os
 | 
			
		||||
from transformers.modeling_utils import _add_variant
 | 
			
		||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		||||
from ..utils.common import invalidInputError
 | 
			
		||||
from typing import Union
 | 
			
		||||
import torch
 | 
			
		||||
from torch import nn
 | 
			
		||||
import logging
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
| 
						 | 
				
			
			@ -179,3 +182,79 @@ def get_xpu_device_type(x):
 | 
			
		|||
        return "pvc"
 | 
			
		||||
    else:
 | 
			
		||||
        return "others"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_imatrix_data(imatrix_file):
 | 
			
		||||
    # this function is adapted from https://github.com/ggerganov/llama.cpp/blob/
 | 
			
		||||
    # c82d18e863fcde91b4b1109b1d0c73ea4470c405/examples/quantize/quantize.cpp#L102
 | 
			
		||||
    imatrix = open(imatrix_file, 'rb')
 | 
			
		||||
    n_entries = imatrix.read(4)
 | 
			
		||||
    n_entries = int.from_bytes(n_entries, 'little')
 | 
			
		||||
    invalidInputError(n_entries >= 1,
 | 
			
		||||
                      f"failed reading name for entry from {imatrix_file}")
 | 
			
		||||
    imatrix_data = {}
 | 
			
		||||
    for i in range(n_entries):
 | 
			
		||||
        cur_len = imatrix.read(4)
 | 
			
		||||
        cur_len = int.from_bytes(cur_len, 'little')
 | 
			
		||||
        cur_name = str(imatrix.read(cur_len), encoding='utf-8')
 | 
			
		||||
        # original cur_name looks like blk.14.attn_output.weight for llama
 | 
			
		||||
        # TODO: how to better aligned and generalize
 | 
			
		||||
        name_list = cur_name.split('.')
 | 
			
		||||
        layer = name_list[1]
 | 
			
		||||
        module_name = name_list[2]
 | 
			
		||||
        if 'ffn' in module_name:
 | 
			
		||||
            module_name = module_name[4:]  # from ffn_gate to gate
 | 
			
		||||
        elif 'attn' in module_name:
 | 
			
		||||
            module_name = module_name[5]  # from attn_k to k, attn_output to o
 | 
			
		||||
        module_name = layer + '_' + module_name
 | 
			
		||||
        ncall = imatrix.read(4)
 | 
			
		||||
        ncall = int.from_bytes(ncall, 'little')
 | 
			
		||||
        nval = imatrix.read(4)
 | 
			
		||||
        nval = int.from_bytes(nval, 'little')
 | 
			
		||||
        invalidInputError(nval >= 1,
 | 
			
		||||
                          f"failed reading number of values for entry {i}")
 | 
			
		||||
        byte_data = imatrix.read(4 * nval)
 | 
			
		||||
        idata = np.frombuffer(byte_data, dtype=np.float32)
 | 
			
		||||
 | 
			
		||||
        if ncall > 0:
 | 
			
		||||
            idata = idata / ncall
 | 
			
		||||
        imatrix_data[module_name] = torch.from_numpy(idata).float()
 | 
			
		||||
 | 
			
		||||
    print(f"loaded {len(imatrix_data)} importance matrix entries from {imatrix_file}.")
 | 
			
		||||
    return imatrix_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
 | 
			
		||||
    if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
 | 
			
		||||
        # For quantization which needs importance matrix
 | 
			
		||||
        # module name preprocess
 | 
			
		||||
        # full name maybe model.layers.31.self_attn.o_proj
 | 
			
		||||
        # TODO: just consider llama/mistral here
 | 
			
		||||
        # TODO: how to better aligned and generalize
 | 
			
		||||
        module_name = full_module_name.split('.')
 | 
			
		||||
        cur_qtype = qtype
 | 
			
		||||
        if len(module_name) == 5:
 | 
			
		||||
            layer = module_name[2]
 | 
			
		||||
            cur_module = module_name[-1][:-5]
 | 
			
		||||
            new_module_name = '_'.join([layer, cur_module])
 | 
			
		||||
        elif len(module_name) == 1:
 | 
			
		||||
            new_module_name = module_name[0]
 | 
			
		||||
            layer = None
 | 
			
		||||
            cur_module = None
 | 
			
		||||
        if imatrix_data is not None and new_module_name in imatrix_data:
 | 
			
		||||
            cur_imatrix = imatrix_data[new_module_name]
 | 
			
		||||
            # custom mixed quantization strategy
 | 
			
		||||
            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
 | 
			
		||||
                    or new_module_name == 'lm_head':
 | 
			
		||||
                cur_qtype = ggml_tensor_qtype['sym_int4']
 | 
			
		||||
        else:
 | 
			
		||||
            cur_imatrix = None
 | 
			
		||||
            # custom mixed quantization strategy
 | 
			
		||||
            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
 | 
			
		||||
                    or new_module_name == 'lm_head':
 | 
			
		||||
                cur_qtype = ggml_tensor_qtype['sym_int4']
 | 
			
		||||
    else:
 | 
			
		||||
        cur_imatrix = None
 | 
			
		||||
        cur_qtype = qtype
 | 
			
		||||
 | 
			
		||||
    return cur_qtype, cur_imatrix
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue