support gguf_q4k_m / gguf_q4k_s (#10887)
* initial commit * UPDATE * fix style * fix style * add gguf_q4k_s * update comment * fix
This commit is contained in:
parent
981d668be6
commit
f1156e6b20
5 changed files with 82 additions and 35 deletions
|
|
@ -47,8 +47,13 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
|
||||||
"gguf_iq1_m": 25,
|
"gguf_iq1_m": 25,
|
||||||
"q6_k": 26,
|
"q6_k": 26,
|
||||||
"q4_k": 27,
|
"q4_k": 27,
|
||||||
|
"q5_k": 28,
|
||||||
"fp6": 29}
|
"fp6": 29}
|
||||||
|
|
||||||
|
# mixed precison from llama.cpp
|
||||||
|
gguf_mixed_qtype = {"gguf_q4k_s": 101,
|
||||||
|
"gguf_q4k_m": 102}
|
||||||
|
|
||||||
_llama_quantize_type = {"q4_0": 2,
|
_llama_quantize_type = {"q4_0": 2,
|
||||||
"q4_1": 3,
|
"q4_1": 3,
|
||||||
"q5_0": 8,
|
"q5_0": 8,
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ from accelerate import init_empty_weights
|
||||||
import warnings
|
import warnings
|
||||||
import transformers
|
import transformers
|
||||||
import importlib.util
|
import importlib.util
|
||||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
|
||||||
from .utils import logger, get_cur_qtype_and_imatrix
|
from .utils import logger, get_cur_qtype_and_imatrix
|
||||||
from typing import Union
|
from typing import Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -337,15 +337,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||||
if in_features % 64 != 0:
|
if in_features % 64 != 0:
|
||||||
# now our kernel requires in_features is a multiple of 64
|
# now our kernel requires in_features is a multiple of 64
|
||||||
continue
|
continue
|
||||||
new_linear = LowBitLinear(
|
|
||||||
in_features,
|
|
||||||
out_features,
|
|
||||||
qtype,
|
|
||||||
module.bias is not None,
|
|
||||||
mp_group=mp_group,
|
|
||||||
enable_xetla=enable_xetla,
|
|
||||||
optimize_lm_head=optimize_lm_head
|
|
||||||
)
|
|
||||||
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
|
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
|
||||||
full_module_name,
|
full_module_name,
|
||||||
imatrix_data,
|
imatrix_data,
|
||||||
|
|
@ -355,6 +346,16 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||||
if cur_qtype in [ggml_tensor_qtype["sym_int4"],
|
if cur_qtype in [ggml_tensor_qtype["sym_int4"],
|
||||||
ggml_tensor_qtype["asym_int4"]]:
|
ggml_tensor_qtype["asym_int4"]]:
|
||||||
cur_qtype = ggml_tensor_qtype["sym_int8"]
|
cur_qtype = ggml_tensor_qtype["sym_int8"]
|
||||||
|
|
||||||
|
new_linear = LowBitLinear(
|
||||||
|
in_features,
|
||||||
|
out_features,
|
||||||
|
cur_qtype,
|
||||||
|
module.bias is not None,
|
||||||
|
mp_group=mp_group,
|
||||||
|
enable_xetla=enable_xetla,
|
||||||
|
optimize_lm_head=optimize_lm_head
|
||||||
|
)
|
||||||
device = module.weight.data.device
|
device = module.weight.data.device
|
||||||
# Copy the weights
|
# Copy the weights
|
||||||
paramsLowBit = FP4Params(data=module.weight.data,
|
paramsLowBit = FP4Params(data=module.weight.data,
|
||||||
|
|
@ -766,9 +767,16 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
|
||||||
embedding_qtype=None,
|
embedding_qtype=None,
|
||||||
enable_xetla=False,
|
enable_xetla=False,
|
||||||
mixed_precision=False):
|
mixed_precision=False):
|
||||||
logger.info(f"Converting the current model to "
|
if qtype in ggml_tensor_qtype.values():
|
||||||
f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
|
index = list(ggml_tensor_qtype.values()).index(qtype)
|
||||||
f"format......")
|
logger.info(f"Converting the current model to "
|
||||||
|
f"{list(ggml_tensor_qtype.keys())[index]} "
|
||||||
|
f"format......")
|
||||||
|
else:
|
||||||
|
index = list(gguf_mixed_qtype.values()).index(qtype)
|
||||||
|
logger.info(f"Converting the current model to "
|
||||||
|
f"{list(gguf_mixed_qtype.keys())[index]} "
|
||||||
|
f"format......")
|
||||||
modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
|
modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
|
||||||
|
|
||||||
# using ipex_llm optimizer before changing to bigdl linear
|
# using ipex_llm optimizer before changing to bigdl linear
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,7 @@ Q2_K = ggml_tensor_qtype["q2_k"]
|
||||||
IQ1_S = ggml_tensor_qtype["gguf_iq1_s"]
|
IQ1_S = ggml_tensor_qtype["gguf_iq1_s"]
|
||||||
Q4_K = ggml_tensor_qtype["q4_k"]
|
Q4_K = ggml_tensor_qtype["q4_k"]
|
||||||
Q6_K = ggml_tensor_qtype["q6_k"]
|
Q6_K = ggml_tensor_qtype["q6_k"]
|
||||||
|
Q5_K = ggml_tensor_qtype["q5_k"]
|
||||||
|
|
||||||
|
|
||||||
# For sym_int4
|
# For sym_int4
|
||||||
|
|
@ -219,7 +220,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
|
||||||
if not convert_shape_only and device != 'meta':
|
if not convert_shape_only and device != 'meta':
|
||||||
dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
|
dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
|
||||||
hist = (ctypes.c_int64 * 16)()
|
hist = (ctypes.c_int64 * 16)()
|
||||||
if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K]:
|
if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K]:
|
||||||
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
|
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
|
||||||
else:
|
else:
|
||||||
if imatrix is not None:
|
if imatrix is not None:
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ from transformers.configuration_utils import PretrainedConfig
|
||||||
from .utils import extract_local_archive_file, \
|
from .utils import extract_local_archive_file, \
|
||||||
load_state_dict, \
|
load_state_dict, \
|
||||||
get_local_shard_files, load_imatrix_data
|
get_local_shard_files, load_imatrix_data
|
||||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
|
||||||
from ipex_llm.utils.common import invalidInputError
|
from ipex_llm.utils.common import invalidInputError
|
||||||
from ipex_llm.transformers.gguf.api import load_gguf_model
|
from ipex_llm.transformers.gguf.api import load_gguf_model
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -117,12 +117,12 @@ class _BaseAutoModelClass:
|
||||||
Default to be ``False``.
|
Default to be ``False``.
|
||||||
:param load_in_low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``,
|
:param load_in_low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``,
|
||||||
``'sym_int5'``, ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``,
|
``'sym_int5'``, ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``,
|
||||||
``'nf4'``, ``'fp4'``, ``'fp6'`` ``'fp8'``, ``'fp8_e4m3'``,
|
``'nf4'``, ``'fp4'``, ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``,
|
||||||
``'fp8_e5m2'``, ``'gguf_iq2_xxs'``, ``'gguf_iq2_xs'``,
|
``'fp6'``, ``'gguf_iq2_xxs'``, ``'gguf_iq2_xs'``,
|
||||||
``'gguf_iq1_s'``, ``'fp16'``, ``'bf16'``, ``'q4_k'`` or
|
``'gguf_iq1_s'``, ``'gguf_q4k_m'``, ``'gguf_q4k_s'``,
|
||||||
``'q6_k'``, ``'sym_int4'`` means symmetric int 4,
|
``'fp16'``, ``'bf16'``,
|
||||||
``'asym_int4'`` means asymmetric int 4,
|
``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
|
||||||
``'nf4'`` means 4-bit NormalFloat, etc.
|
asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
|
||||||
Relevant low bit optimizations will be applied to the model.
|
Relevant low bit optimizations will be applied to the model.
|
||||||
:param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
|
:param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
|
||||||
Default to be ``True``.
|
Default to be ``True``.
|
||||||
|
|
@ -139,8 +139,9 @@ class _BaseAutoModelClass:
|
||||||
added to llama.cpp.
|
added to llama.cpp.
|
||||||
:param model_hub: str value, options are ``'huggingface'`` and ``'modelscope'``,
|
:param model_hub: str value, options are ``'huggingface'`` and ``'modelscope'``,
|
||||||
specify the model hub. Default to be ``'huggingface'``.
|
specify the model hub. Default to be ``'huggingface'``.
|
||||||
:param embedding_qtype: str value, options are ``'q2_k'`` now. Default to be None.
|
:param embedding_qtype: str value, options are ``'q2_k'``, ``'q4_k'`` now.
|
||||||
Relevant low bit optimizations will be applied to nn.Embedding layer.
|
Default to be None. Relevant low bit optimizations will be applied to
|
||||||
|
``nn.Embedding`` layer.
|
||||||
:param mixed_precision: boolean value, Whether to use mixed precision quantization.
|
:param mixed_precision: boolean value, Whether to use mixed precision quantization.
|
||||||
Default to be False. If set to True, we will use sym_int8 for lm_head when
|
Default to be False. If set to True, we will use sym_int8 for lm_head when
|
||||||
load_in_low_bit is sym_int4 or asym_int4.
|
load_in_low_bit is sym_int4 or asym_int4.
|
||||||
|
|
@ -321,10 +322,12 @@ class _BaseAutoModelClass:
|
||||||
"For gguf_iq2 and gguf_iq1 quantization,"
|
"For gguf_iq2 and gguf_iq1 quantization,"
|
||||||
"imatrix is needed.")
|
"imatrix is needed.")
|
||||||
cpu_embedding = kwargs.get("cpu_embedding", False)
|
cpu_embedding = kwargs.get("cpu_embedding", False)
|
||||||
# for 2bit, default use embedding_quantization
|
# for iq2/k-quants, default use embedding_quantization
|
||||||
if q_k in ["gguf_iq2_xxs", "gguf_iq2_xs", "gguf_iq1_s", "q2_k"] and \
|
if not cpu_embedding and embedding_qtype is None:
|
||||||
not cpu_embedding and embedding_qtype is None:
|
if q_k in ["gguf_iq2_xxs", "gguf_iq2_xs", "gguf_iq1_s", "q2_k"]:
|
||||||
embedding_qtype = "q2_k"
|
embedding_qtype = "q2_k"
|
||||||
|
elif q_k in ["gguf_q4k_s", "gguf_q4k_m"]:
|
||||||
|
embedding_qtype = "q4_k"
|
||||||
if imatrix_file is not None:
|
if imatrix_file is not None:
|
||||||
imatrix_data = load_imatrix_data(imatrix_file)
|
imatrix_data = load_imatrix_data(imatrix_file)
|
||||||
kwargs["imatrix_data"] = imatrix_data
|
kwargs["imatrix_data"] = imatrix_data
|
||||||
|
|
@ -376,12 +379,16 @@ class _BaseAutoModelClass:
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_convert(cls, q_k, optimize_model, *args, **kwargs):
|
def load_convert(cls, q_k, optimize_model, *args, **kwargs):
|
||||||
from .convert import ggml_convert_low_bit
|
from .convert import ggml_convert_low_bit
|
||||||
invalidInputError(q_k in ggml_tensor_qtype,
|
invalidInputError(q_k in ggml_tensor_qtype or q_k in gguf_mixed_qtype,
|
||||||
f"Unknown load_in_low_bit value: {q_k}, expected:"
|
f"Unknown load_in_low_bit value: {q_k}, expected:"
|
||||||
f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
|
f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
|
||||||
f"fp4, fp6, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, gguf_iq2_xxs, "
|
f"fp4, fp6, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, gguf_iq2_xxs, "
|
||||||
f"gguf_iq2_xs, gguf_iq1_s, q2_k, q4_k, q6_k, mixed_fp4 or mixed_fp8.")
|
f"gguf_iq2_xs, gguf_iq1_s, q2_k, q4_k, q5_k, q6_k, "
|
||||||
qtype = ggml_tensor_qtype[q_k]
|
f"gguf_q4k_s, gguf_q4k_m, mixed_fp4 or mixed_fp8.")
|
||||||
|
if q_k in ggml_tensor_qtype:
|
||||||
|
qtype = ggml_tensor_qtype[q_k]
|
||||||
|
else:
|
||||||
|
qtype = gguf_mixed_qtype[q_k]
|
||||||
|
|
||||||
# In case it needs a second try,
|
# In case it needs a second try,
|
||||||
# `from_pretrained`` may pop items out in dict
|
# `from_pretrained`` may pop items out in dict
|
||||||
|
|
@ -550,17 +557,24 @@ class _BaseAutoModelClass:
|
||||||
" with load_in_4bit or load_in_low_bit to get a low-bit model , and "
|
" with load_in_4bit or load_in_low_bit to get a low-bit model , and "
|
||||||
" serialize the model using save_low_bit first.")
|
" serialize the model using save_low_bit first.")
|
||||||
|
|
||||||
invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype,
|
invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype or
|
||||||
|
bigdl_transformers_low_bit in gguf_mixed_qtype,
|
||||||
f"Unknown bigdl_transformers_low_bit value: {bigdl_transformers_low_bit},"
|
f"Unknown bigdl_transformers_low_bit value: {bigdl_transformers_low_bit},"
|
||||||
f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
|
f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
|
||||||
|
|
||||||
# set default optimize_model=True
|
# set default optimize_model=True
|
||||||
optimize_model = kwargs.pop("optimize_model", True)
|
optimize_model = kwargs.pop("optimize_model", True)
|
||||||
|
|
||||||
qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
|
if bigdl_transformers_low_bit in ggml_tensor_qtype:
|
||||||
|
qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
|
||||||
|
else:
|
||||||
|
qtype = gguf_mixed_qtype[bigdl_transformers_low_bit]
|
||||||
if bigdl_transformers_low_bit in ["gguf_iq2_xxs", "gguf_iq2_xs", "gguf_iq1_s", "q2_k"] and \
|
if bigdl_transformers_low_bit in ["gguf_iq2_xxs", "gguf_iq2_xs", "gguf_iq1_s", "q2_k"] and \
|
||||||
not cpu_embedding:
|
not cpu_embedding:
|
||||||
embedding_qtype = "q2_k"
|
embedding_qtype = "q2_k"
|
||||||
|
elif bigdl_transformers_low_bit in ["gguf_q4k_s", "gguf_q4k_m"] and \
|
||||||
|
not cpu_embedding:
|
||||||
|
embedding_qtype = "q4_k"
|
||||||
if embedding_qtype is not None:
|
if embedding_qtype is not None:
|
||||||
embedding_qtype = ggml_tensor_qtype[embedding_qtype]
|
embedding_qtype = ggml_tensor_qtype[embedding_qtype]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
import os
|
import os
|
||||||
from transformers.modeling_utils import _add_variant
|
from transformers.modeling_utils import _add_variant
|
||||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
|
||||||
from ..utils.common import invalidInputError
|
from ..utils.common import invalidInputError
|
||||||
from typing import Union, Optional
|
from typing import Union, Optional
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -271,10 +271,12 @@ def module_name_process(full_module_name):
|
||||||
|
|
||||||
def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
|
def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
|
||||||
cur_qtype = qtype
|
cur_qtype = qtype
|
||||||
|
cur_imatrix = None
|
||||||
if model_config is not None:
|
if model_config is not None:
|
||||||
model_type = getattr(model_config, "model_type", None)
|
model_type = getattr(model_config, "model_type", None)
|
||||||
else:
|
else:
|
||||||
model_dtype = None
|
model_dtype = None
|
||||||
|
|
||||||
if qtype in [ggml_tensor_qtype["gguf_iq2_xxs"], ggml_tensor_qtype["gguf_iq2_xs"],
|
if qtype in [ggml_tensor_qtype["gguf_iq2_xxs"], ggml_tensor_qtype["gguf_iq2_xs"],
|
||||||
ggml_tensor_qtype["gguf_iq1_s"]]:
|
ggml_tensor_qtype["gguf_iq1_s"]]:
|
||||||
# For quantization which needs importance matrix
|
# For quantization which needs importance matrix
|
||||||
|
|
@ -306,7 +308,6 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_confi
|
||||||
cur_imatrix = None
|
cur_imatrix = None
|
||||||
if new_module_name == 'lm_head':
|
if new_module_name == 'lm_head':
|
||||||
cur_qtype = ggml_tensor_qtype['sym_int8']
|
cur_qtype = ggml_tensor_qtype['sym_int8']
|
||||||
return cur_qtype, cur_imatrix
|
|
||||||
elif qtype == ggml_tensor_qtype["q2_k"]:
|
elif qtype == ggml_tensor_qtype["q2_k"]:
|
||||||
new_module_name, layer, cur_module = module_name_process(full_module_name)
|
new_module_name, layer, cur_module = module_name_process(full_module_name)
|
||||||
if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
|
if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
|
||||||
|
|
@ -319,8 +320,26 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_confi
|
||||||
cur_imatrix = None
|
cur_imatrix = None
|
||||||
if new_module_name == 'lm_head':
|
if new_module_name == 'lm_head':
|
||||||
cur_qtype = ggml_tensor_qtype['sym_int8']
|
cur_qtype = ggml_tensor_qtype['sym_int8']
|
||||||
|
elif qtype > 100:
|
||||||
|
# gguf mixed precision
|
||||||
|
new_module_name, layer, cur_module = module_name_process(full_module_name)
|
||||||
|
num_hidden_layers = getattr(model_config, "num_hidden_layers", None)
|
||||||
|
if qtype in [gguf_mixed_qtype["gguf_q4k_s"], gguf_mixed_qtype["gguf_q4k_m"]] and \
|
||||||
|
new_module_name == 'lm_head':
|
||||||
|
cur_qtype = ggml_tensor_qtype['q6_k']
|
||||||
|
elif qtype == gguf_mixed_qtype["gguf_q4k_m"]:
|
||||||
|
if int(layer) < int(num_hidden_layers/2) and cur_module in ['v', 'down']:
|
||||||
|
cur_qtype = ggml_tensor_qtype['q6_k']
|
||||||
|
else:
|
||||||
|
cur_qtype = ggml_tensor_qtype['q4_k']
|
||||||
|
elif qtype == gguf_mixed_qtype["gguf_q4k_s"]:
|
||||||
|
if int(layer) < int(num_hidden_layers/8) and cur_module in ['v', 'down']:
|
||||||
|
cur_qtype = ggml_tensor_qtype['q5_k']
|
||||||
|
else:
|
||||||
|
cur_qtype = ggml_tensor_qtype['q4_k']
|
||||||
else:
|
else:
|
||||||
return qtype, None
|
pass
|
||||||
|
return cur_qtype, cur_imatrix
|
||||||
|
|
||||||
|
|
||||||
def get_modelscope_hf_config(model_id_or_path: str,
|
def get_modelscope_hf_config(model_id_or_path: str,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue