LLM: 2bit quantization initial support (#10042)

* basic quantize support

* fix new module name

* small update

* mixed int4 with iq2_xxs

* remove print

* code refactor

* fix style

* meet code review
Ruonan Wang 2024-02-06 14:58:32 +08:00 committed by GitHub
parent f440cb4fba
commit d61f4905ac
6 changed files with 164 additions and 22 deletions

View file

@@ -965,6 +965,30 @@ _lib.ggml_quantize_tensor.argtypes = [
_lib.ggml_quantize_tensor.restype = ctypes.c_size_t


def ggml_quantize_tensor_with_weights(
    src,  # type: ctypes.Array[ctypes.c_float] # type: ignore
    dst: ctypes.c_void_p,
    qtype: ctypes.c_int,
    nrow: ctypes.c_int,
    n_per_row: ctypes.c_int,
    hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
    weights,  # type: ctypes.Array[ctypes.c_float] # type: ignore
) -> int:
    return _lib.ggml_quantize_tensor_with_weights(src, dst, qtype, nrow, n_per_row, hist, weights)


_lib.ggml_quantize_tensor_with_weights.argtypes = [
    ctypes.POINTER(ctypes.c_float),
    ctypes.c_void_p,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.POINTER(ctypes.c_int64),
    ctypes.POINTER(ctypes.c_float),
]
_lib.ggml_quantize_tensor_with_weights.restype = ctypes.c_size_t


def ggml_type_size(qtype: ctypes.c_int) -> int:
    return _lib.ggml_type_size(qtype)
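Note: a minimal sketch of how this new binding is driven from Python, mirroring the call site added in low_bit_linear.py further down. It assumes a contiguous float32 weight of shape [out_features, in_features], an importance matrix holding one float per input column, and that ggml_qk_size is the existing block-size helper in this module.

import ctypes
import torch

def quantize_with_imatrix(weight: torch.Tensor, imatrix: torch.Tensor, qtype: int) -> torch.Tensor:
    n, k = weight.numel(), weight.shape[-1]          # total elements, elements per row
    dst = torch.empty((n // ggml_qk_size(qtype)) * ggml_type_size(qtype), dtype=torch.uint8)
    hist = (ctypes.c_int64 * 16)()                   # histogram buffer expected by ggml
    src = ctypes.cast(weight.data_ptr(), ctypes.POINTER(ctypes.c_float))
    weights = ctypes.cast(imatrix.data_ptr(), ctypes.POINTER(ctypes.c_float))
    ggml_quantize_tensor_with_weights(src, ctypes.c_void_p(dst.data_ptr()), qtype,
                                      n // k, k, hist, weights)
    return dst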

View file

@@ -39,7 +39,9 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
                     "mixed_fp8": 18, # Mixture of Formats Quantization 8 bits
                     "fp8_e5m2": 19, # fp8 in e5m2 format
                     "fp8": 19, # fp8 in e5m2 format
                     "bf16": 20,
                     "iq2_xxs": 21,
                     "iq2_xs": 22}
_llama_quantize_type = {"q4_0": 2,
                        "q4_1": 3,

View file

@@ -43,7 +43,7 @@ import warnings
import transformers
import importlib.util
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from .utils import logger, get_cur_qtype_and_imatrix
from typing import Union
import numpy as np
import os
@@ -190,7 +190,8 @@ def convert_gptq(module, awq=False, llm_awq=False):
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
current_key_name=None, convert_shape_only=False,
cpu_embedding=False, prefix_name='',
imatrix_data=None):
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
FP16Linear, BF16Linear
from bigdl.llm.transformers.embedding import LLMEmbedding
@@ -248,7 +249,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
module.bias is not None,
mp_group=mp_group,
)
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
full_module_name,
imatrix_data)
device = module.weight.data.device
# Copy the weights
paramsLowBit = FP4Params(data=module.weight.data,
@@ -256,7 +259,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
quantized=False,
_shape=None,
convert_shape_only=convert_shape_only,
qtype=cur_qtype,
imatrix=cur_imatrix,
in_features=in_features).to(device)
new_linear._parameters['weight'] = paramsLowBit
if module.bias is not None:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -328,7 +333,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
current_key_name,
convert_shape_only,
cpu_embedding,
prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
imatrix_data=imatrix_data
)
has_been_replaced = _flag or has_been_replaced
return model, has_been_replaced
@@ -505,7 +511,8 @@ def _optimize_pre(model):
def ggml_convert_low_bit(model, qtype, optimize_model=True,
convert_shape_only=False, device="cpu",
modules_to_not_convert=None, cpu_embedding=False,
lightweight_bmm=False, torch_dtype="auto",
imatrix_data=None):
logger.info(f"Converting the current model to "
f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
f"format......")
@@ -517,6 +524,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
model, has_been_replaced = _replace_with_low_bit_linear(
model, qtype, modules_to_not_convert,
None, convert_shape_only, cpu_embedding,
imatrix_data=imatrix_data,
)
if not has_been_replaced:
warnings.warn(

View file

@@ -70,6 +70,8 @@ FP4 = ggml_tensor_qtype["fp4"]
MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
IQ2_XXS = ggml_tensor_qtype["iq2_xxs"]
IQ2_XS = ggml_tensor_qtype["iq2_xs"]
def get_block_size(qtype: str):
@@ -81,7 +83,9 @@ def get_qk_size(qtype: int):
def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
device=None, convert_shape_only=False,
imatrix: torch.Tensor=None,
in_features: int=None):
QK = ggml.ggml_qk_size(qtype)
block_size_in_bytes = ggml.ggml_type_size(qtype)
@@ -89,12 +93,10 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
"Input tensor must be float32")
src = tensor.data.data_ptr()
src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
n = tensor.numel() # all elements
k = tensor.shape[-1]
invalidInputError(k % QK == 0,
f"Last dim of input tensor must be multiple of {QK}")
dst_size = (n // QK) * block_size_in_bytes
dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
@@ -103,7 +105,16 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
if not convert_shape_only and device != 'meta':
    dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
    hist = (ctypes.c_int64 * 16)()
    if qtype not in [IQ2_XXS, IQ2_XS]:
        ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
    else:
        # quantize with importance matrix
        imatrix = imatrix.data.data_ptr()
        imatrix = ctypes.cast(imatrix, ctypes.POINTER(ctypes.c_float))
        # pass nrow and n_per_row
        ggml.ggml_quantize_tensor_with_weights(src, dst, qtype,
                                               n // in_features, in_features,
                                               hist, imatrix)
return dst_tensor
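Note on the row geometry used above: a 2-D weight of shape [out_features, in_features] has n // in_features rows of in_features values each, and the importance matrix is assumed to carry one float per input column, so the quantizer can weight columns while walking the tensor row by row. A quick sanity check with a hypothetical Llama-7B down_proj shape:

import torch

out_features, in_features = 4096, 11008          # illustrative shape only
weight = torch.randn(out_features, in_features)
imatrix = torch.ones(in_features)                # one importance value per input column (assumed layout)
nrow, n_per_row = weight.numel() // in_features, in_features
assert (nrow, n_per_row) == (4096, 11008)
assert imatrix.numel() == n_per_row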
@@ -193,7 +204,9 @@ class FP4Params(torch.nn.Parameter):
quantized=False,
_shape=None,
convert_shape_only=False,
qtype=None,
imatrix=None,
in_features=None):
if data is None:
data = torch.empty(0)
@@ -203,6 +216,8 @@ class FP4Params(torch.nn.Parameter):
self._shape = _shape
self.qtype = qtype
self.convert_shape_only = convert_shape_only
self.imatrix = imatrix
self.in_features = in_features
return self
def ggml_mse(self, w, ggml_qtype, device):
@@ -255,7 +270,9 @@ class FP4Params(torch.nn.Parameter):
else:
w_quantized = ggml_convert_qtype(w, self.qtype,
device=device,
convert_shape_only=self.convert_shape_only,
imatrix=self.imatrix,
in_features=self.in_features)
self.data = w_quantized
self.quantized = True
self._shape = w.shape

View file

@@ -41,7 +41,7 @@ import transformers
from transformers.configuration_utils import PretrainedConfig
from .utils import extract_local_archive_file, \
load_state_dict, \
get_local_shard_files, load_imatrix_data
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.gguf.api import load_gguf_model
@@ -107,10 +107,10 @@ class _BaseAutoModelClass:
:param load_in_low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``,
``'sym_int5'``, ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``,
``'nf4'``, ``'fp4'``, ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``,
``'iq2_xxs'``, ``'iq2_xs'``, ``'fp16'`` or ``'bf16'``,
``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
Relevant low bit optimizations will be applied to the model.
:param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
Default to be ``True``.
:param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
@@ -121,6 +121,9 @@ class _BaseAutoModelClass:
to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
:param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
:param imatrix: str value, representing the filename of an importance matrix pretrained
on specific datasets, for use with the improved quantization methods recently
added to llama.cpp.
:return: a model instance
"""
pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
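Note: putting the pieces together, the intended user-facing call looks roughly like the sketch below. The import path and model id are illustrative, the imatrix filename is a placeholder produced by an external calibration run (e.g. llama.cpp's imatrix tool), and only iq2_xxs/iq2_xs require it.

from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_low_bit="iq2_xxs",
                                             imatrix="llama-v2-7b.imatrix",  # placeholder path
                                             torch_dtype="auto")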
@@ -243,6 +246,12 @@ class _BaseAutoModelClass:
else:
kwargs["pretraining_tp"] = 1
q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
if q_k in ["iq2_xxs", "iq2_xs"]:
    imatrix_file = kwargs.pop("imatrix", None)
    invalidInputError(imatrix_file is not None,
                      "For iq2_xxs and iq2_xs quantization, imatrix is needed.")
    imatrix_data = load_imatrix_data(imatrix_file)
    kwargs['imatrix_data'] = imatrix_data
model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
if speculative:
@@ -285,7 +294,8 @@ class _BaseAutoModelClass:
invalidInputError(q_k in ggml_tensor_qtype,
f"Unknown load_in_low_bit value: {q_k}, expected:"
f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, iq2_xxs, iq2_xs, "
f"mixed_fp4 or mixed_fp8.")
qtype = ggml_tensor_qtype[q_k]
# In case it needs a second try,
@@ -299,6 +309,7 @@ class _BaseAutoModelClass:
cpu_embedding = True
lightweight_bmm = kwargs.pop("lightweight_bmm", False)
quant_config = kwargs.pop("quantization_config", None)
imatrix_data = kwargs.pop("imatrix_data", None)
_args = copy.deepcopy(args)
_kwargs = copy.deepcopy(kwargs)
awq_config = None
@@ -359,7 +370,8 @@ class _BaseAutoModelClass:
model = ggml_convert_low_bit(model, qtype, optimize_model,
modules_to_not_convert=modules_to_not_convert,
cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
torch_dtype=kwargs.get("torch_dtype", 'auto'),
imatrix_data=imatrix_data)
model.config.update({"bigdl_transformers_low_bit": q_k})
# enable tie_word_embeddings for MPT

View file

@@ -41,11 +41,14 @@
# SOFTWARE.
import os
from transformers.modeling_utils import _add_variant
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from ..utils.common import invalidInputError
from typing import Union
import torch
from torch import nn
import logging
import numpy as np
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -179,3 +182,79 @@ def get_xpu_device_type(x):
return "pvc"
else:
return "others"
def load_imatrix_data(imatrix_file):
    # this function is adapted from https://github.com/ggerganov/llama.cpp/blob/
    # c82d18e863fcde91b4b1109b1d0c73ea4470c405/examples/quantize/quantize.cpp#L102
    imatrix = open(imatrix_file, 'rb')
    n_entries = imatrix.read(4)
    n_entries = int.from_bytes(n_entries, 'little')
    invalidInputError(n_entries >= 1,
                      f"failed reading name for entry from {imatrix_file}")
    imatrix_data = {}
    for i in range(n_entries):
        cur_len = imatrix.read(4)
        cur_len = int.from_bytes(cur_len, 'little')
        cur_name = str(imatrix.read(cur_len), encoding='utf-8')
        # original cur_name looks like blk.14.attn_output.weight for llama
        # TODO: how to better aligned and generalize
        name_list = cur_name.split('.')
        layer = name_list[1]
        module_name = name_list[2]
        if 'ffn' in module_name:
            module_name = module_name[4:]  # from ffn_gate to gate
        elif 'attn' in module_name:
            module_name = module_name[5]  # from attn_k to k, attn_output to o
        module_name = layer + '_' + module_name
        ncall = imatrix.read(4)
        ncall = int.from_bytes(ncall, 'little')
        nval = imatrix.read(4)
        nval = int.from_bytes(nval, 'little')
        invalidInputError(nval >= 1,
                          f"failed reading number of values for entry {i}")
        byte_data = imatrix.read(4 * nval)
        idata = np.frombuffer(byte_data, dtype=np.float32)
        if ncall > 0:
            idata = idata / ncall
        imatrix_data[module_name] = torch.from_numpy(idata).float()
    print(f"loaded {len(imatrix_data)} importance matrix entries from {imatrix_file}.")
    return imatrix_data
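Note: from the reader above, the expected on-disk layout is an int32 entry count, then per entry an int32 name length, the UTF-8 name, an int32 call count, an int32 value count, and that many float32 values (any trailing bytes are ignored). A small writer sketch for building test fixtures under that assumption; the function and argument names are mine, not part of the commit.

import struct

def write_imatrix(path, entries):
    # entries: dict mapping names like "blk.0.attn_k.weight" to (ncall, list_of_float_values)
    with open(path, 'wb') as f:
        f.write(struct.pack('<i', len(entries)))
        for name, (ncall, values) in entries.items():
            name_bytes = name.encode('utf-8')
            f.write(struct.pack('<i', len(name_bytes)))
            f.write(name_bytes)
            f.write(struct.pack('<i', ncall))
            f.write(struct.pack('<i', len(values)))
            f.write(struct.pack(f'<{len(values)}f', *values))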
def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
    if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
        # For quantization which needs importance matrix
        # module name preprocess
        # full name maybe model.layers.31.self_attn.o_proj
        # TODO: just consider llama/mistral here
        # TODO: how to better aligned and generalize
        module_name = full_module_name.split('.')
        cur_qtype = qtype
        if len(module_name) == 5:
            layer = module_name[2]
            cur_module = module_name[-1][:-5]
            new_module_name = '_'.join([layer, cur_module])
        elif len(module_name) == 1:
            new_module_name = module_name[0]
            layer = None
            cur_module = None
        if imatrix_data is not None and new_module_name in imatrix_data:
            cur_imatrix = imatrix_data[new_module_name]
            # custom mixed quantization strategy
            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
                    or new_module_name == 'lm_head':
                cur_qtype = ggml_tensor_qtype['sym_int4']
        else:
            cur_imatrix = None
            # custom mixed quantization strategy
            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
                    or new_module_name == 'lm_head':
                cur_qtype = ggml_tensor_qtype['sym_int4']
    else:
        cur_imatrix = None
        cur_qtype = qtype
    return cur_qtype, cur_imatrix
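Note: to make the name mapping concrete, load_imatrix_data stores blk.31.attn_output.weight under the key 31_o, and get_cur_qtype_and_imatrix reduces model.layers.31.self_attn.o_proj to the same 31_o, so the lookup lines up for llama/mistral-style models. A small illustration, assuming the helpers are importable from bigdl.llm.transformers.utils:

from bigdl.llm.transformers.utils import get_cur_qtype_and_imatrix
from bigdl.llm.ggml.quantize import ggml_tensor_qtype

qtype = ggml_tensor_qtype["iq2_xxs"]

# o_proj in layer 31 keeps iq2_xxs; with a loaded imatrix_data dict the matching
# importance row would be returned as the second value
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype, "model.layers.31.self_attn.o_proj", None)

# v_proj, down_proj in layers 0/1/10/11 and lm_head fall back to sym_int4
cur_qtype, _ = get_cur_qtype_and_imatrix(qtype, "model.layers.0.mlp.down_proj", None)
assert cur_qtype == ggml_tensor_qtype["sym_int4"]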