From d61f4905ace167f88600d3e530ec234ecc097fc5 Mon Sep 17 00:00:00 2001 From: Ruonan Wang Date: Tue, 6 Feb 2024 14:58:32 +0800 Subject: [PATCH] LLM: 2bit quantization initial support (#10042) * basis quantize support * fix new module name * small update * and mixed int4 with iq2_xxs * remove print * code refactor * fix style * meet code review --- .../bigdl/llm/ggml/model/llama/llama_cpp.py | 24 ++++++ python/llm/src/bigdl/llm/ggml/quantize.py | 4 +- .../llm/src/bigdl/llm/transformers/convert.py | 20 +++-- .../bigdl/llm/transformers/low_bit_linear.py | 33 ++++++-- .../llm/src/bigdl/llm/transformers/model.py | 26 ++++-- .../llm/src/bigdl/llm/transformers/utils.py | 79 +++++++++++++++++++ 6 files changed, 164 insertions(+), 22 deletions(-) diff --git a/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py b/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py index 78f5681b..b7b5d2ed 100644 --- a/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py +++ b/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py @@ -965,6 +965,30 @@ _lib.ggml_quantize_tensor.argtypes = [ _lib.ggml_quantize_tensor.restype = ctypes.c_size_t +def ggml_quantize_tensor_with_weights( + src, # type: ctypes.Array[ctypes.c_float] # type: ignore + dst: ctypes.c_void_p, + qtype: ctypes.c_int, + nrow: ctypes.c_int, + n_per_row: ctypes.c_int, + hist, # type: ctypes.Array[ctypes.c_int64] # type: ignore + weights, # type: ctypes.Array[ctypes.c_float] # type: ignore +) -> int: + return _lib.ggml_quantize_tensor_with_weights(src, dst, qtype, nrow, n_per_row, hist, weights) + + +_lib.ggml_quantize_tensor_with_weights.argtypes = [ + ctypes.POINTER(ctypes.c_float), + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.POINTER(ctypes.c_int64), + ctypes.POINTER(ctypes.c_float), +] +_lib.ggml_quantize_tensor_with_weights.restype = ctypes.c_size_t + + def ggml_type_size(qtype: ctypes.c_int) -> int: return _lib.ggml_type_size(qtype) diff --git a/python/llm/src/bigdl/llm/ggml/quantize.py b/python/llm/src/bigdl/llm/ggml/quantize.py index 11e04ab8..31be0f8f 100644 --- a/python/llm/src/bigdl/llm/ggml/quantize.py +++ b/python/llm/src/bigdl/llm/ggml/quantize.py @@ -39,7 +39,9 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml "mixed_fp8": 18, # Mixture of Formats Quantization 8 bits "fp8_e5m2": 19, # fp8 in e5m2 format "fp8": 19, # fp8 in e5m2 format - "bf16": 20} + "bf16": 20, + "iq2_xxs": 21, + "iq2_xs": 22} _llama_quantize_type = {"q4_0": 2, "q4_1": 3, diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 518de262..7c1aaec8 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -43,7 +43,7 @@ import warnings import transformers import importlib.util from bigdl.llm.ggml.quantize import ggml_tensor_qtype -from .utils import logger +from .utils import logger, get_cur_qtype_and_imatrix from typing import Union import numpy as np import os @@ -190,7 +190,8 @@ def convert_gptq(module, awq=False, llm_awq=False): def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, current_key_name=None, convert_shape_only=False, - cpu_embedding=False, prefix_name=''): + cpu_embedding=False, prefix_name='', + imatrix_data=None): from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \ FP16Linear, BF16Linear from bigdl.llm.transformers.embedding import LLMEmbedding @@ -248,7 +249,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, 
module.bias is not None, mp_group=mp_group, ) - + cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype, + full_module_name, + imatrix_data) device = module.weight.data.device # Copy the weights paramsLowBit = FP4Params(data=module.weight.data, @@ -256,7 +259,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, quantized=False, _shape=None, convert_shape_only=convert_shape_only, - qtype=qtype).to(device) + qtype=cur_qtype, + imatrix=cur_imatrix, + in_features=in_features).to(device) new_linear._parameters['weight'] = paramsLowBit if module.bias is not None: new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ @@ -328,7 +333,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, current_key_name, convert_shape_only, cpu_embedding, - prefix_name=prefix_name + '.' + name if prefix_name != '' else name + prefix_name=prefix_name + '.' + name if prefix_name != '' else name, + imatrix_data=imatrix_data ) has_been_replaced = _flag or has_been_replaced return model, has_been_replaced @@ -505,7 +511,8 @@ def _optimize_pre(model): def ggml_convert_low_bit(model, qtype, optimize_model=True, convert_shape_only=False, device="cpu", modules_to_not_convert=None, cpu_embedding=False, - lightweight_bmm=False, torch_dtype="auto"): + lightweight_bmm=False, torch_dtype="auto", + imatrix_data=None): logger.info(f"Converting the current model to " f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} " f"format......") @@ -517,6 +524,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True, model, has_been_replaced = _replace_with_low_bit_linear( model, qtype, modules_to_not_convert, None, convert_shape_only, cpu_embedding, + imatrix_data=imatrix_data, ) if not has_been_replaced: warnings.warn( diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index 08dbab8f..62c97da1 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -70,6 +70,8 @@ FP4 = ggml_tensor_qtype["fp4"] MOFQ4 = ggml_tensor_qtype["mixed_fp4"] MOFQ8 = ggml_tensor_qtype["mixed_fp8"] FP8E5 = ggml_tensor_qtype["fp8_e5m2"] +IQ2_XXS = ggml_tensor_qtype["iq2_xxs"] +IQ2_XS = ggml_tensor_qtype["iq2_xs"] def get_block_size(qtype: str): @@ -81,7 +83,9 @@ def get_qk_size(qtype: int): def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, - device=None, convert_shape_only=False): + device=None, convert_shape_only=False, + imatrix: torch.Tensor=None, + in_features: int=None): QK = ggml.ggml_qk_size(qtype) block_size_in_bytes = ggml.ggml_type_size(qtype) @@ -89,12 +93,10 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, "Input tensor must be float32") src = tensor.data.data_ptr() src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float)) - n = tensor.numel() - invalidInputError(n % QK == 0, - "Input tensor size must be multiple of 64") + n = tensor.numel() # all elements k = tensor.shape[-1] invalidInputError(k % QK == 0, - "Last dim of input tensor must be multiple of 64") + f"Last dim of input tensor must be multiple of {QK}") dst_size = (n // QK) * block_size_in_bytes dst_tensor = torch.empty(dst_size, dtype=torch.uint8, @@ -103,7 +105,16 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, if not convert_shape_only and device != 'meta': dst = ctypes.c_void_p(dst_tensor.data.data_ptr()) hist = (ctypes.c_int64 * 16)() - ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist) + if qtype not in 
[IQ2_XXS, IQ2_XS]: + ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist) + else: + # quantize with importance matrix + imatrix = imatrix.data.data_ptr() + imatrix = ctypes.cast(imatrix, ctypes.POINTER(ctypes.c_float)) + # pass nrow and n_per_row + ggml.ggml_quantize_tensor_with_weights(src, dst, qtype, + n // in_features, in_features, + hist, imatrix) return dst_tensor @@ -193,7 +204,9 @@ class FP4Params(torch.nn.Parameter): quantized=False, _shape=None, convert_shape_only=False, - qtype=None): + qtype=None, + imatrix=None, + in_features=None): if data is None: data = torch.empty(0) @@ -203,6 +216,8 @@ class FP4Params(torch.nn.Parameter): self._shape = _shape self.qtype = qtype self.convert_shape_only = convert_shape_only + self.imatrix = imatrix + self.in_features = in_features return self def ggml_mse(self, w, ggml_qtype, device): @@ -255,7 +270,9 @@ class FP4Params(torch.nn.Parameter): else: w_quantized = ggml_convert_qtype(w, self.qtype, device=device, - convert_shape_only=self.convert_shape_only) + convert_shape_only=self.convert_shape_only, + imatrix=self.imatrix, + in_features=self.in_features) self.data = w_quantized self.quantized = True self._shape = w.shape diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index 1e8e431b..8b6d6b2f 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -41,7 +41,7 @@ import transformers from transformers.configuration_utils import PretrainedConfig from .utils import extract_local_archive_file, \ load_state_dict, \ - get_local_shard_files + get_local_shard_files, load_imatrix_data from bigdl.llm.ggml.quantize import ggml_tensor_qtype from bigdl.llm.utils.common import invalidInputError from bigdl.llm.transformers.gguf.api import load_gguf_model @@ -107,10 +107,10 @@ class _BaseAutoModelClass: :param load_in_low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``, ``'sym_int5'``, ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``, ``'nf4'``, ``'fp4'``, ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``, - ``'fp16'`` or ``'bf16'``, ``'sym_int4'`` means symmetric int 4, - ``'asym_int4'`` means asymmetric int 4, ``'nf4'`` means 4-bit - NormalFloat, etc. Relevant low bit optimizations will be applied - to the model. + ``'iq2_xxs'``, ``'iq2_xs'``, ``'fp16'`` or ``'bf16'``, + ``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means + asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc. + Relevant low bit optimizations will be applied to the model. :param optimize_model: boolean value, Whether to further optimize the low_bit llm model. Default to be ``True``. :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when @@ -121,6 +121,9 @@ class _BaseAutoModelClass: to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``. :param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``. + :param imatrix: str value, represent filename of importance matrix pretrained on + specific datasets for use with the improved quantization methods recently + added to llama.cpp. 
:return: a model instance """ pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \ @@ -243,6 +246,12 @@ class _BaseAutoModelClass: else: kwargs["pretraining_tp"] = 1 q_k = load_in_low_bit if load_in_low_bit else "sym_int4" + if q_k in ["iq2_xxs", "iq2_xs"]: + imatrix_file = kwargs.pop("imatrix", None) + invalidInputError(imatrix_file is not None, + "For iq2_xxs and iq2_xs quantization, imatrix is needed.") + imatrix_data = load_imatrix_data(imatrix_file) + kwargs['imatrix_data'] = imatrix_data model = cls.load_convert(q_k, optimize_model, *args, **kwargs) if speculative: @@ -285,7 +294,8 @@ class _BaseAutoModelClass: invalidInputError(q_k in ggml_tensor_qtype, f"Unknown load_in_low_bit value: {q_k}, expected:" f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, " - "fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, mixed_fp4 or mixed_fp8.") + f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, iq2_xxs, iq2_xs, " + f"mixed_fp4 or mixed_fp8.") qtype = ggml_tensor_qtype[q_k] # In case it needs a second try, @@ -299,6 +309,7 @@ class _BaseAutoModelClass: cpu_embedding = True lightweight_bmm = kwargs.pop("lightweight_bmm", False) quant_config = kwargs.pop("quantization_config", None) + imatrix_data = kwargs.pop("imatrix_data", None) _args = copy.deepcopy(args) _kwargs = copy.deepcopy(kwargs) awq_config = None @@ -359,7 +370,8 @@ class _BaseAutoModelClass: model = ggml_convert_low_bit(model, qtype, optimize_model, modules_to_not_convert=modules_to_not_convert, cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm, - torch_dtype=kwargs.get("torch_dtype", 'auto')) + torch_dtype=kwargs.get("torch_dtype", 'auto'), + imatrix_data=imatrix_data) model.config.update({"bigdl_transformers_low_bit": q_k}) # enable tie_word_embeddings for MPT diff --git a/python/llm/src/bigdl/llm/transformers/utils.py b/python/llm/src/bigdl/llm/transformers/utils.py index 89c93b55..06e3a0fd 100644 --- a/python/llm/src/bigdl/llm/transformers/utils.py +++ b/python/llm/src/bigdl/llm/transformers/utils.py @@ -41,11 +41,14 @@ # SOFTWARE. 
import os from transformers.modeling_utils import _add_variant +from bigdl.llm.ggml.quantize import ggml_tensor_qtype from ..utils.common import invalidInputError from typing import Union import torch from torch import nn import logging +import numpy as np + logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) logger = logging.getLogger(__name__) @@ -179,3 +182,79 @@ def get_xpu_device_type(x): return "pvc" else: return "others" + + +def load_imatrix_data(imatrix_file): + # this function is adapted from https://github.com/ggerganov/llama.cpp/blob/ + # c82d18e863fcde91b4b1109b1d0c73ea4470c405/examples/quantize/quantize.cpp#L102 + imatrix = open(imatrix_file, 'rb') + n_entries = imatrix.read(4) + n_entries = int.from_bytes(n_entries, 'little') + invalidInputError(n_entries >= 1, + f"failed reading name for entry from {imatrix_file}") + imatrix_data = {} + for i in range(n_entries): + cur_len = imatrix.read(4) + cur_len = int.from_bytes(cur_len, 'little') + cur_name = str(imatrix.read(cur_len), encoding='utf-8') + # original cur_name looks like blk.14.attn_output.weight for llama + # TODO: how to better aligned and generalize + name_list = cur_name.split('.') + layer = name_list[1] + module_name = name_list[2] + if 'ffn' in module_name: + module_name = module_name[4:] # from ffn_gate to gate + elif 'attn' in module_name: + module_name = module_name[5] # from attn_k to k, attn_output to o + module_name = layer + '_' + module_name + ncall = imatrix.read(4) + ncall = int.from_bytes(ncall, 'little') + nval = imatrix.read(4) + nval = int.from_bytes(nval, 'little') + invalidInputError(nval >= 1, + f"failed reading number of values for entry {i}") + byte_data = imatrix.read(4 * nval) + idata = np.frombuffer(byte_data, dtype=np.float32) + + if ncall > 0: + idata = idata / ncall + imatrix_data[module_name] = torch.from_numpy(idata).float() + + print(f"loaded {len(imatrix_data)} importance matrix entries from {imatrix_file}.") + return imatrix_data + + +def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data): + if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]: + # For quantization which needs importance matrix + # module name preprocess + # full name maybe model.layers.31.self_attn.o_proj + # TODO: just consider llama/mistral here + # TODO: how to better aligned and generalize + module_name = full_module_name.split('.') + cur_qtype = qtype + if len(module_name) == 5: + layer = module_name[2] + cur_module = module_name[-1][:-5] + new_module_name = '_'.join([layer, cur_module]) + elif len(module_name) == 1: + new_module_name = module_name[0] + layer = None + cur_module = None + if imatrix_data is not None and new_module_name in imatrix_data: + cur_imatrix = imatrix_data[new_module_name] + # custom mixed quantization strategy + if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \ + or new_module_name == 'lm_head': + cur_qtype = ggml_tensor_qtype['sym_int4'] + else: + cur_imatrix = None + # custom mixed quantization strategy + if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \ + or new_module_name == 'lm_head': + cur_qtype = ggml_tensor_qtype['sym_int4'] + else: + cur_imatrix = None + cur_qtype = qtype + + return cur_qtype, cur_imatrix
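
A minimal usage sketch of the API this patch adds, assuming a Llama-2 style Hugging Face checkpoint and an importance-matrix file produced by llama.cpp's imatrix tool; the model path, imatrix filename, and prompt below are placeholders.

# Usage sketch for the new iq2_xxs / iq2_xs path. Passing load_in_low_bit="iq2_xxs"
# together with the imatrix kwarg triggers load_imatrix_data() and the
# imatrix-aware quantization introduced in this patch.
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-hf"    # placeholder checkpoint
imatrix_path = "llama-2-7b.imatrix"        # placeholder file from llama.cpp's imatrix tool

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_low_bit="iq2_xxs",   # or "iq2_xs"
    imatrix=imatrix_path,        # required for the iq2_* qtypes
    optimize_model=True,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Internally, from_pretrained pops the imatrix kwarg, loads it with load_imatrix_data, and threads the resulting dict through ggml_convert_low_bit into FP4Params, whose quantize path calls ggml_quantize_tensor_with_weights for the iq2_* qtypes.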
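
A small sketch of the module-name normalization that ties load_imatrix_data and get_cur_qtype_and_imatrix together, using a hypothetical one-entry imatrix dict: the llama.cpp entry name blk.14.attn_output.weight and the Hugging Face module name model.layers.14.self_attn.o_proj both reduce to the key 14_o, while the custom mixed strategy drops back to sym_int4 for v_proj, for down_proj in layers 0/1/10/11, and for lm_head.

import torch
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.utils import get_cur_qtype_and_imatrix

# Hypothetical entry: "14_o" is the key load_imatrix_data() derives from
# "blk.14.attn_output.weight" in the imatrix file.
imatrix_data = {"14_o": torch.ones(4096)}  # dummy importance values

# o_proj keeps the 2-bit qtype and picks up its importance vector.
qtype, imatrix = get_cur_qtype_and_imatrix(ggml_tensor_qtype["iq2_xxs"],
                                           "model.layers.14.self_attn.o_proj",
                                           imatrix_data)
assert qtype == ggml_tensor_qtype["iq2_xxs"] and imatrix is imatrix_data["14_o"]

# v_proj (like down_proj in layers 0/1/10/11 and lm_head) is forced to sym_int4.
qtype, _ = get_cur_qtype_and_imatrix(ggml_tensor_qtype["iq2_xxs"],
                                     "model.layers.14.self_attn.v_proj",
                                     imatrix_data)
assert qtype == ggml_tensor_qtype["sym_int4"]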
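
For reference, a hypothetical helper (not part of BigDL-LLM or llama.cpp) that writes the little-endian binary layout load_imatrix_data expects, mirroring the format read by the parser above: an int32 entry count, then per entry an int32 name length, the UTF-8 name, an int32 ncall, an int32 nval, and nval float32 values.

import struct
import numpy as np

def write_imatrix_file(path, entries, ncall=1):
    """Write a file parseable by load_imatrix_data(). `entries` maps
    llama.cpp-style names (e.g. "blk.0.attn_k.weight") to 1-D arrays of
    per-column importance values; load_imatrix_data divides them by ncall."""
    with open(path, "wb") as f:
        f.write(struct.pack("<i", len(entries)))
        for name, values in entries.items():
            name_bytes = name.encode("utf-8")
            values = np.asarray(values, dtype=np.float32)
            f.write(struct.pack("<i", len(name_bytes)))
            f.write(name_bytes)
            f.write(struct.pack("<i", ncall))          # number of accumulated calls
            f.write(struct.pack("<i", values.size))    # number of values
            f.write(values.tobytes())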