LLM: 2bit quantization initial support (#10042)
* basis quantize support * fix new module name * small update * and mixed int4 with iq2_xxs * remove print * code refactor * fix style * meet code review
This commit is contained in:
parent
f440cb4fba
commit
d61f4905ac
6 changed files with 164 additions and 22 deletions
|
|
@ -965,6 +965,30 @@ _lib.ggml_quantize_tensor.argtypes = [
|
|||
_lib.ggml_quantize_tensor.restype = ctypes.c_size_t
|
||||
|
||||
|
||||
def ggml_quantize_tensor_with_weights(
|
||||
src, # type: ctypes.Array[ctypes.c_float] # type: ignore
|
||||
dst: ctypes.c_void_p,
|
||||
qtype: ctypes.c_int,
|
||||
nrow: ctypes.c_int,
|
||||
n_per_row: ctypes.c_int,
|
||||
hist, # type: ctypes.Array[ctypes.c_int64] # type: ignore
|
||||
weights, # type: ctypes.Array[ctypes.c_float] # type: ignore
|
||||
) -> int:
|
||||
return _lib.ggml_quantize_tensor_with_weights(src, dst, qtype, nrow, n_per_row, hist, weights)
|
||||
|
||||
|
||||
_lib.ggml_quantize_tensor_with_weights.argtypes = [
|
||||
ctypes.POINTER(ctypes.c_float),
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
ctypes.POINTER(ctypes.c_int64),
|
||||
ctypes.POINTER(ctypes.c_float),
|
||||
]
|
||||
_lib.ggml_quantize_tensor_with_weights.restype = ctypes.c_size_t
|
||||
|
||||
|
||||
def ggml_type_size(qtype: ctypes.c_int) -> int:
|
||||
return _lib.ggml_type_size(qtype)
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,9 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
|
|||
"mixed_fp8": 18, # Mixture of Formats Quantization 8 bits
|
||||
"fp8_e5m2": 19, # fp8 in e5m2 format
|
||||
"fp8": 19, # fp8 in e5m2 format
|
||||
"bf16": 20}
|
||||
"bf16": 20,
|
||||
"iq2_xxs": 21,
|
||||
"iq2_xs": 22}
|
||||
|
||||
_llama_quantize_type = {"q4_0": 2,
|
||||
"q4_1": 3,
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ import warnings
|
|||
import transformers
|
||||
import importlib.util
|
||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
from .utils import logger
|
||||
from .utils import logger, get_cur_qtype_and_imatrix
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
import os
|
||||
|
|
@ -190,7 +190,8 @@ def convert_gptq(module, awq=False, llm_awq=False):
|
|||
|
||||
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||
current_key_name=None, convert_shape_only=False,
|
||||
cpu_embedding=False, prefix_name=''):
|
||||
cpu_embedding=False, prefix_name='',
|
||||
imatrix_data=None):
|
||||
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
|
||||
FP16Linear, BF16Linear
|
||||
from bigdl.llm.transformers.embedding import LLMEmbedding
|
||||
|
|
@ -248,7 +249,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
)
|
||||
|
||||
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
|
||||
full_module_name,
|
||||
imatrix_data)
|
||||
device = module.weight.data.device
|
||||
# Copy the weights
|
||||
paramsLowBit = FP4Params(data=module.weight.data,
|
||||
|
|
@ -256,7 +259,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
quantized=False,
|
||||
_shape=None,
|
||||
convert_shape_only=convert_shape_only,
|
||||
qtype=qtype).to(device)
|
||||
qtype=cur_qtype,
|
||||
imatrix=cur_imatrix,
|
||||
in_features=in_features).to(device)
|
||||
new_linear._parameters['weight'] = paramsLowBit
|
||||
if module.bias is not None:
|
||||
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
||||
|
|
@ -328,7 +333,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
current_key_name,
|
||||
convert_shape_only,
|
||||
cpu_embedding,
|
||||
prefix_name=prefix_name + '.' + name if prefix_name != '' else name
|
||||
prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
|
||||
imatrix_data=imatrix_data
|
||||
)
|
||||
has_been_replaced = _flag or has_been_replaced
|
||||
return model, has_been_replaced
|
||||
|
|
@ -505,7 +511,8 @@ def _optimize_pre(model):
|
|||
def ggml_convert_low_bit(model, qtype, optimize_model=True,
|
||||
convert_shape_only=False, device="cpu",
|
||||
modules_to_not_convert=None, cpu_embedding=False,
|
||||
lightweight_bmm=False, torch_dtype="auto"):
|
||||
lightweight_bmm=False, torch_dtype="auto",
|
||||
imatrix_data=None):
|
||||
logger.info(f"Converting the current model to "
|
||||
f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
|
||||
f"format......")
|
||||
|
|
@ -517,6 +524,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
|
|||
model, has_been_replaced = _replace_with_low_bit_linear(
|
||||
model, qtype, modules_to_not_convert,
|
||||
None, convert_shape_only, cpu_embedding,
|
||||
imatrix_data=imatrix_data,
|
||||
)
|
||||
if not has_been_replaced:
|
||||
warnings.warn(
|
||||
|
|
|
|||
|
|
@ -70,6 +70,8 @@ FP4 = ggml_tensor_qtype["fp4"]
|
|||
MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
|
||||
MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
|
||||
FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
|
||||
IQ2_XXS = ggml_tensor_qtype["iq2_xxs"]
|
||||
IQ2_XS = ggml_tensor_qtype["iq2_xs"]
|
||||
|
||||
|
||||
def get_block_size(qtype: str):
|
||||
|
|
@ -81,7 +83,9 @@ def get_qk_size(qtype: int):
|
|||
|
||||
|
||||
def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
|
||||
device=None, convert_shape_only=False):
|
||||
device=None, convert_shape_only=False,
|
||||
imatrix: torch.Tensor=None,
|
||||
in_features: int=None):
|
||||
QK = ggml.ggml_qk_size(qtype)
|
||||
block_size_in_bytes = ggml.ggml_type_size(qtype)
|
||||
|
||||
|
|
@ -89,12 +93,10 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
|
|||
"Input tensor must be float32")
|
||||
src = tensor.data.data_ptr()
|
||||
src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
|
||||
n = tensor.numel()
|
||||
invalidInputError(n % QK == 0,
|
||||
"Input tensor size must be multiple of 64")
|
||||
n = tensor.numel() # all elements
|
||||
k = tensor.shape[-1]
|
||||
invalidInputError(k % QK == 0,
|
||||
"Last dim of input tensor must be multiple of 64")
|
||||
f"Last dim of input tensor must be multiple of {QK}")
|
||||
|
||||
dst_size = (n // QK) * block_size_in_bytes
|
||||
dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
|
||||
|
|
@ -103,7 +105,16 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
|
|||
if not convert_shape_only and device != 'meta':
|
||||
dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
|
||||
hist = (ctypes.c_int64 * 16)()
|
||||
if qtype not in [IQ2_XXS, IQ2_XS]:
|
||||
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
|
||||
else:
|
||||
# quantize with importance matrix
|
||||
imatrix = imatrix.data.data_ptr()
|
||||
imatrix = ctypes.cast(imatrix, ctypes.POINTER(ctypes.c_float))
|
||||
# pass nrow and n_per_row
|
||||
ggml.ggml_quantize_tensor_with_weights(src, dst, qtype,
|
||||
n // in_features, in_features,
|
||||
hist, imatrix)
|
||||
return dst_tensor
|
||||
|
||||
|
||||
|
|
@ -193,7 +204,9 @@ class FP4Params(torch.nn.Parameter):
|
|||
quantized=False,
|
||||
_shape=None,
|
||||
convert_shape_only=False,
|
||||
qtype=None):
|
||||
qtype=None,
|
||||
imatrix=None,
|
||||
in_features=None):
|
||||
if data is None:
|
||||
data = torch.empty(0)
|
||||
|
||||
|
|
@ -203,6 +216,8 @@ class FP4Params(torch.nn.Parameter):
|
|||
self._shape = _shape
|
||||
self.qtype = qtype
|
||||
self.convert_shape_only = convert_shape_only
|
||||
self.imatrix = imatrix
|
||||
self.in_features = in_features
|
||||
return self
|
||||
|
||||
def ggml_mse(self, w, ggml_qtype, device):
|
||||
|
|
@ -255,7 +270,9 @@ class FP4Params(torch.nn.Parameter):
|
|||
else:
|
||||
w_quantized = ggml_convert_qtype(w, self.qtype,
|
||||
device=device,
|
||||
convert_shape_only=self.convert_shape_only)
|
||||
convert_shape_only=self.convert_shape_only,
|
||||
imatrix=self.imatrix,
|
||||
in_features=self.in_features)
|
||||
self.data = w_quantized
|
||||
self.quantized = True
|
||||
self._shape = w.shape
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ import transformers
|
|||
from transformers.configuration_utils import PretrainedConfig
|
||||
from .utils import extract_local_archive_file, \
|
||||
load_state_dict, \
|
||||
get_local_shard_files
|
||||
get_local_shard_files, load_imatrix_data
|
||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
from bigdl.llm.transformers.gguf.api import load_gguf_model
|
||||
|
|
@ -107,10 +107,10 @@ class _BaseAutoModelClass:
|
|||
:param load_in_low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``,
|
||||
``'sym_int5'``, ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``,
|
||||
``'nf4'``, ``'fp4'``, ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``,
|
||||
``'fp16'`` or ``'bf16'``, ``'sym_int4'`` means symmetric int 4,
|
||||
``'asym_int4'`` means asymmetric int 4, ``'nf4'`` means 4-bit
|
||||
NormalFloat, etc. Relevant low bit optimizations will be applied
|
||||
to the model.
|
||||
``'iq2_xxs'``, ``'iq2_xs'``, ``'fp16'`` or ``'bf16'``,
|
||||
``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
|
||||
asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
|
||||
Relevant low bit optimizations will be applied to the model.
|
||||
:param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
|
||||
Default to be ``True``.
|
||||
:param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
|
||||
|
|
@ -121,6 +121,9 @@ class _BaseAutoModelClass:
|
|||
to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
|
||||
:param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
|
||||
to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
|
||||
:param imatrix: str value, represent filename of importance matrix pretrained on
|
||||
specific datasets for use with the improved quantization methods recently
|
||||
added to llama.cpp.
|
||||
:return: a model instance
|
||||
"""
|
||||
pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
|
||||
|
|
@ -243,6 +246,12 @@ class _BaseAutoModelClass:
|
|||
else:
|
||||
kwargs["pretraining_tp"] = 1
|
||||
q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
|
||||
if q_k in ["iq2_xxs", "iq2_xs"]:
|
||||
imatrix_file = kwargs.pop("imatrix", None)
|
||||
invalidInputError(imatrix_file is not None,
|
||||
"For iq2_xxs and iq2_xs quantization, imatrix is needed.")
|
||||
imatrix_data = load_imatrix_data(imatrix_file)
|
||||
kwargs['imatrix_data'] = imatrix_data
|
||||
model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
|
||||
|
||||
if speculative:
|
||||
|
|
@ -285,7 +294,8 @@ class _BaseAutoModelClass:
|
|||
invalidInputError(q_k in ggml_tensor_qtype,
|
||||
f"Unknown load_in_low_bit value: {q_k}, expected:"
|
||||
f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
|
||||
"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, mixed_fp4 or mixed_fp8.")
|
||||
f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, iq2_xxs, iq2_xs, "
|
||||
f"mixed_fp4 or mixed_fp8.")
|
||||
qtype = ggml_tensor_qtype[q_k]
|
||||
|
||||
# In case it needs a second try,
|
||||
|
|
@ -299,6 +309,7 @@ class _BaseAutoModelClass:
|
|||
cpu_embedding = True
|
||||
lightweight_bmm = kwargs.pop("lightweight_bmm", False)
|
||||
quant_config = kwargs.pop("quantization_config", None)
|
||||
imatrix_data = kwargs.pop("imatrix_data", None)
|
||||
_args = copy.deepcopy(args)
|
||||
_kwargs = copy.deepcopy(kwargs)
|
||||
awq_config = None
|
||||
|
|
@ -359,7 +370,8 @@ class _BaseAutoModelClass:
|
|||
model = ggml_convert_low_bit(model, qtype, optimize_model,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
|
||||
torch_dtype=kwargs.get("torch_dtype", 'auto'))
|
||||
torch_dtype=kwargs.get("torch_dtype", 'auto'),
|
||||
imatrix_data=imatrix_data)
|
||||
model.config.update({"bigdl_transformers_low_bit": q_k})
|
||||
|
||||
# enable tie_word_embeddings for MPT
|
||||
|
|
|
|||
|
|
@ -41,11 +41,14 @@
|
|||
# SOFTWARE.
|
||||
import os
|
||||
from transformers.modeling_utils import _add_variant
|
||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
from ..utils.common import invalidInputError
|
||||
from typing import Union
|
||||
import torch
|
||||
from torch import nn
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
|
||||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -179,3 +182,79 @@ def get_xpu_device_type(x):
|
|||
return "pvc"
|
||||
else:
|
||||
return "others"
|
||||
|
||||
|
||||
def load_imatrix_data(imatrix_file):
|
||||
# this function is adapted from https://github.com/ggerganov/llama.cpp/blob/
|
||||
# c82d18e863fcde91b4b1109b1d0c73ea4470c405/examples/quantize/quantize.cpp#L102
|
||||
imatrix = open(imatrix_file, 'rb')
|
||||
n_entries = imatrix.read(4)
|
||||
n_entries = int.from_bytes(n_entries, 'little')
|
||||
invalidInputError(n_entries >= 1,
|
||||
f"failed reading name for entry from {imatrix_file}")
|
||||
imatrix_data = {}
|
||||
for i in range(n_entries):
|
||||
cur_len = imatrix.read(4)
|
||||
cur_len = int.from_bytes(cur_len, 'little')
|
||||
cur_name = str(imatrix.read(cur_len), encoding='utf-8')
|
||||
# original cur_name looks like blk.14.attn_output.weight for llama
|
||||
# TODO: how to better aligned and generalize
|
||||
name_list = cur_name.split('.')
|
||||
layer = name_list[1]
|
||||
module_name = name_list[2]
|
||||
if 'ffn' in module_name:
|
||||
module_name = module_name[4:] # from ffn_gate to gate
|
||||
elif 'attn' in module_name:
|
||||
module_name = module_name[5] # from attn_k to k, attn_output to o
|
||||
module_name = layer + '_' + module_name
|
||||
ncall = imatrix.read(4)
|
||||
ncall = int.from_bytes(ncall, 'little')
|
||||
nval = imatrix.read(4)
|
||||
nval = int.from_bytes(nval, 'little')
|
||||
invalidInputError(nval >= 1,
|
||||
f"failed reading number of values for entry {i}")
|
||||
byte_data = imatrix.read(4 * nval)
|
||||
idata = np.frombuffer(byte_data, dtype=np.float32)
|
||||
|
||||
if ncall > 0:
|
||||
idata = idata / ncall
|
||||
imatrix_data[module_name] = torch.from_numpy(idata).float()
|
||||
|
||||
print(f"loaded {len(imatrix_data)} importance matrix entries from {imatrix_file}.")
|
||||
return imatrix_data
|
||||
|
||||
|
||||
def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
|
||||
if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
|
||||
# For quantization which needs importance matrix
|
||||
# module name preprocess
|
||||
# full name maybe model.layers.31.self_attn.o_proj
|
||||
# TODO: just consider llama/mistral here
|
||||
# TODO: how to better aligned and generalize
|
||||
module_name = full_module_name.split('.')
|
||||
cur_qtype = qtype
|
||||
if len(module_name) == 5:
|
||||
layer = module_name[2]
|
||||
cur_module = module_name[-1][:-5]
|
||||
new_module_name = '_'.join([layer, cur_module])
|
||||
elif len(module_name) == 1:
|
||||
new_module_name = module_name[0]
|
||||
layer = None
|
||||
cur_module = None
|
||||
if imatrix_data is not None and new_module_name in imatrix_data:
|
||||
cur_imatrix = imatrix_data[new_module_name]
|
||||
# custom mixed quantization strategy
|
||||
if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
|
||||
or new_module_name == 'lm_head':
|
||||
cur_qtype = ggml_tensor_qtype['sym_int4']
|
||||
else:
|
||||
cur_imatrix = None
|
||||
# custom mixed quantization strategy
|
||||
if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
|
||||
or new_module_name == 'lm_head':
|
||||
cur_qtype = ggml_tensor_qtype['sym_int4']
|
||||
else:
|
||||
cur_imatrix = None
|
||||
cur_qtype = qtype
|
||||
|
||||
return cur_qtype, cur_imatrix
|
||||
|
|
|
|||
Loading…
Reference in a new issue