LLM: basic support for q2k (#10132)

* basic support for q2k

* fix style
Ruonan Wang 2024-02-08 13:52:01 +08:00 committed by GitHub
parent 11fe5a87ec
commit 063dc145ac
4 changed files with 14 additions and 11 deletions

File 1 of 4 — ggml_tensor_qtype mapping:

@@ -41,7 +41,8 @@ ggml_tensor_qtype = {"sym_int4": 2,  # q4_0 in ggml
                      "fp8": 19,  # fp8 in e5m2 format
                      "bf16": 20,
                      "iq2_xxs": 21,
-                     "iq2_xs": 22}
+                     "iq2_xs": 22,
+                     "q2_k": 23}
 
 _llama_quantize_type = {"q4_0": 2,
                         "q4_1": 3,

File 2 of 4 — low-bit tensor conversion (ggml_convert_qtype):

@@ -72,6 +72,7 @@ MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
 FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
 IQ2_XXS = ggml_tensor_qtype["iq2_xxs"]
 IQ2_XS = ggml_tensor_qtype["iq2_xs"]
+Q2_K = ggml_tensor_qtype["q2_k"]
 
 
 def get_block_size(qtype: str):
@@ -105,12 +106,13 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
     if not convert_shape_only and device != 'meta':
         dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
         hist = (ctypes.c_int64 * 16)()
-        if qtype not in [IQ2_XXS, IQ2_XS]:
+        if qtype not in [IQ2_XXS, IQ2_XS, Q2_K]:
             ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
         else:
-            # quantize with importance matrix
-            imatrix = imatrix.data.data_ptr()
-            imatrix = ctypes.cast(imatrix, ctypes.POINTER(ctypes.c_float))
+            if imatrix is not None:
+                # quantize with importance matrix
+                imatrix = imatrix.data.data_ptr()
+                imatrix = ctypes.cast(imatrix, ctypes.POINTER(ctypes.c_float))
             # pass nrow and n_per_row
             ggml.ggml_quantize_tensor_with_weights(src, dst, qtype,
                                                    n // in_features, in_features,
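The behavioral core of this hunk: q2_k joins the iq2 types on the imatrix-aware quantization path, but unlike them it tolerates a missing importance matrix — the pointer cast is simply skipped and the None weights pointer flows through to ggml. A self-contained sketch of the same guarded ctypes pattern (the function name is hypothetical):

    import ctypes

    def as_float_ptr(imatrix):
        # Mirrors the guard added above: only take the raw data pointer
        # when an importance-matrix tensor was actually supplied.
        if imatrix is None:
            return None  # ggml then quantizes without per-weight importance
        addr = imatrix.data.data_ptr()  # raw address of the torch tensor's storage
        return ctypes.cast(addr, ctypes.POINTER(ctypes.c_float))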

File 3 of 4 — _BaseAutoModelClass loading logic:

@@ -271,10 +271,11 @@ class _BaseAutoModelClass:
         else:
             kwargs["pretraining_tp"] = 1
         q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
+        imatrix_file = kwargs.pop("imatrix", None)
         if q_k in ["iq2_xxs", "iq2_xs"]:
-            imatrix_file = kwargs.pop("imatrix", None)
             invalidInputError(imatrix_file is not None,
                               "For iq2_xxs and iq2_xs quantization, imatrix is needed.")
+        if imatrix_file is not None:
             imatrix_data = load_imatrix_data(imatrix_file)
             kwargs['imatrix_data'] = imatrix_data
         model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
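In user-facing terms, the `imatrix` kwarg is now popped unconditionally and only validated as required for the iq2 formats, so q2_k loads with or without one. A hedged usage sketch (the import path is assumed from this repo's layout; the model id and imatrix filename are placeholders):

    from bigdl.llm.transformers import AutoModelForCausalLM  # assumed import path

    # q2_k no longer errors out when no importance matrix is given...
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                                 load_in_low_bit="q2_k")

    # ...but can still consume one; iq2_xxs/iq2_xs continue to require it.
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                                 load_in_low_bit="q2_k",
                                                 imatrix="llama-2-7b.imatrix")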

File 4 of 4 — get_cur_qtype_and_imatrix helper:

@@ -225,7 +225,8 @@ def load_imatrix_data(imatrix_file):
 
 
 def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
-    if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
+    if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"],
+                 ggml_tensor_qtype["q2_k"]] and imatrix_data is not None:
         # For quantization which needs importance matrix
         # module name preprocess
         # full name maybe model.layers.31.self_attn.o_proj
@@ -253,11 +254,9 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
         if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
                 or new_module_name == 'lm_head':
             cur_qtype = ggml_tensor_qtype['sym_int4']
+        return cur_qtype, cur_imatrix
     else:
-        cur_imatrix = None
-        cur_qtype = qtype
-    return cur_qtype, cur_imatrix
-
+        return qtype, None
 
 
 def get_modelscope_hf_config(model_id_or_path: str,
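To summarize the refactor above: each branch now returns directly, and the imatrix-aware path is only taken when imatrix data was actually loaded. A hedged illustration of the two outcomes (module names follow the llama layout mentioned in the code comments; exact returns depend on preprocessing not shown here):

    # Without imatrix data, the requested qtype passes through untouched:
    assert get_cur_qtype_and_imatrix(
        ggml_tensor_qtype["q2_k"], "model.layers.0.mlp.up_proj", None
    ) == (ggml_tensor_qtype["q2_k"], None)

    # With imatrix data, sensitive modules (v_proj, a few down_proj layers,
    # lm_head) are expected to fall back to sym_int4 while the rest keep
    # q2_k, each paired with its per-module imatrix slice.
    cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(
        ggml_tensor_qtype["q2_k"], "model.layers.0.self_attn.v_proj", imatrix_data)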