From f1156e6b20d2f5895f7e4ed24fbc62e20e4ccdda Mon Sep 17 00:00:00 2001 From: Ruonan Wang Date: Fri, 17 May 2024 06:30:09 +0000 Subject: [PATCH] support gguf_q4k_m / gguf_q4k_s (#10887) * initial commit * UPDATE * fix style * fix style * add gguf_q4k_s * update comment * fix --- python/llm/src/ipex_llm/ggml/quantize.py | 5 ++ .../llm/src/ipex_llm/transformers/convert.py | 34 ++++++++----- .../ipex_llm/transformers/low_bit_linear.py | 3 +- python/llm/src/ipex_llm/transformers/model.py | 50 ++++++++++++------- python/llm/src/ipex_llm/transformers/utils.py | 25 ++++++++-- 5 files changed, 82 insertions(+), 35 deletions(-) diff --git a/python/llm/src/ipex_llm/ggml/quantize.py b/python/llm/src/ipex_llm/ggml/quantize.py index bdaeccaf..3eaf668e 100644 --- a/python/llm/src/ipex_llm/ggml/quantize.py +++ b/python/llm/src/ipex_llm/ggml/quantize.py @@ -47,8 +47,13 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml "gguf_iq1_m": 25, "q6_k": 26, "q4_k": 27, + "q5_k": 28, "fp6": 29} +# mixed precison from llama.cpp +gguf_mixed_qtype = {"gguf_q4k_s": 101, + "gguf_q4k_m": 102} + _llama_quantize_type = {"q4_0": 2, "q4_1": 3, "q5_0": 8, diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index d1ce9f43..639a2154 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -42,7 +42,7 @@ from accelerate import init_empty_weights import warnings import transformers import importlib.util -from ipex_llm.ggml.quantize import ggml_tensor_qtype +from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype from .utils import logger, get_cur_qtype_and_imatrix from typing import Union import numpy as np @@ -337,15 +337,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, if in_features % 64 != 0: # now our kernel requires in_features is a multiple of 64 continue - new_linear = LowBitLinear( - in_features, - out_features, - qtype, - module.bias is not None, - mp_group=mp_group, - enable_xetla=enable_xetla, - optimize_lm_head=optimize_lm_head - ) cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, @@ -355,6 +346,16 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, if cur_qtype in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["asym_int4"]]: cur_qtype = ggml_tensor_qtype["sym_int8"] + + new_linear = LowBitLinear( + in_features, + out_features, + cur_qtype, + module.bias is not None, + mp_group=mp_group, + enable_xetla=enable_xetla, + optimize_lm_head=optimize_lm_head + ) device = module.weight.data.device # Copy the weights paramsLowBit = FP4Params(data=module.weight.data, @@ -766,9 +767,16 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True, embedding_qtype=None, enable_xetla=False, mixed_precision=False): - logger.info(f"Converting the current model to " - f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} " - f"format......") + if qtype in ggml_tensor_qtype.values(): + index = list(ggml_tensor_qtype.values()).index(qtype) + logger.info(f"Converting the current model to " + f"{list(ggml_tensor_qtype.keys())[index]} " + f"format......") + else: + index = list(gguf_mixed_qtype.values()).index(qtype) + logger.info(f"Converting the current model to " + f"{list(gguf_mixed_qtype.keys())[index]} " + f"format......") modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert # using ipex_llm optimizer before changing to bigdl linear diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py index af8eb04d..f129093f 100644 --- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py +++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py @@ -79,6 +79,7 @@ Q2_K = ggml_tensor_qtype["q2_k"] IQ1_S = ggml_tensor_qtype["gguf_iq1_s"] Q4_K = ggml_tensor_qtype["q4_k"] Q6_K = ggml_tensor_qtype["q6_k"] +Q5_K = ggml_tensor_qtype["q5_k"] # For sym_int4 @@ -219,7 +220,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, if not convert_shape_only and device != 'meta': dst = ctypes.c_void_p(dst_tensor.data.data_ptr()) hist = (ctypes.c_int64 * 16)() - if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K]: + if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K]: ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist) else: if imatrix is not None: diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py index e3c5448c..00e7a2f3 100644 --- a/python/llm/src/ipex_llm/transformers/model.py +++ b/python/llm/src/ipex_llm/transformers/model.py @@ -42,7 +42,7 @@ from transformers.configuration_utils import PretrainedConfig from .utils import extract_local_archive_file, \ load_state_dict, \ get_local_shard_files, load_imatrix_data -from ipex_llm.ggml.quantize import ggml_tensor_qtype +from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype from ipex_llm.utils.common import invalidInputError from ipex_llm.transformers.gguf.api import load_gguf_model import torch @@ -117,12 +117,12 @@ class _BaseAutoModelClass: Default to be ``False``. :param load_in_low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``, ``'sym_int5'``, ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``, - ``'nf4'``, ``'fp4'``, ``'fp6'`` ``'fp8'``, ``'fp8_e4m3'``, - ``'fp8_e5m2'``, ``'gguf_iq2_xxs'``, ``'gguf_iq2_xs'``, - ``'gguf_iq1_s'``, ``'fp16'``, ``'bf16'``, ``'q4_k'`` or - ``'q6_k'``, ``'sym_int4'`` means symmetric int 4, - ``'asym_int4'`` means asymmetric int 4, - ``'nf4'`` means 4-bit NormalFloat, etc. + ``'nf4'``, ``'fp4'``, ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``, + ``'fp6'``, ``'gguf_iq2_xxs'``, ``'gguf_iq2_xs'``, + ``'gguf_iq1_s'``, ``'gguf_q4k_m'``, ``'gguf_q4k_s'``, + ``'fp16'``, ``'bf16'``, + ``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means + asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc. Relevant low bit optimizations will be applied to the model. :param optimize_model: boolean value, Whether to further optimize the low_bit llm model. Default to be ``True``. @@ -139,8 +139,9 @@ class _BaseAutoModelClass: added to llama.cpp. :param model_hub: str value, options are ``'huggingface'`` and ``'modelscope'``, specify the model hub. Default to be ``'huggingface'``. - :param embedding_qtype: str value, options are ``'q2_k'`` now. Default to be None. - Relevant low bit optimizations will be applied to nn.Embedding layer. + :param embedding_qtype: str value, options are ``'q2_k'``, ``'q4_k'`` now. + Default to be None. Relevant low bit optimizations will be applied to + ``nn.Embedding`` layer. :param mixed_precision: boolean value, Whether to use mixed precision quantization. Default to be False. If set to True, we will use sym_int8 for lm_head when load_in_low_bit is sym_int4 or asym_int4. @@ -321,10 +322,12 @@ class _BaseAutoModelClass: "For gguf_iq2 and gguf_iq1 quantization," "imatrix is needed.") cpu_embedding = kwargs.get("cpu_embedding", False) - # for 2bit, default use embedding_quantization - if q_k in ["gguf_iq2_xxs", "gguf_iq2_xs", "gguf_iq1_s", "q2_k"] and \ - not cpu_embedding and embedding_qtype is None: - embedding_qtype = "q2_k" + # for iq2/k-quants, default use embedding_quantization + if not cpu_embedding and embedding_qtype is None: + if q_k in ["gguf_iq2_xxs", "gguf_iq2_xs", "gguf_iq1_s", "q2_k"]: + embedding_qtype = "q2_k" + elif q_k in ["gguf_q4k_s", "gguf_q4k_m"]: + embedding_qtype = "q4_k" if imatrix_file is not None: imatrix_data = load_imatrix_data(imatrix_file) kwargs["imatrix_data"] = imatrix_data @@ -376,12 +379,16 @@ class _BaseAutoModelClass: @classmethod def load_convert(cls, q_k, optimize_model, *args, **kwargs): from .convert import ggml_convert_low_bit - invalidInputError(q_k in ggml_tensor_qtype, + invalidInputError(q_k in ggml_tensor_qtype or q_k in gguf_mixed_qtype, f"Unknown load_in_low_bit value: {q_k}, expected:" f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, " f"fp4, fp6, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, gguf_iq2_xxs, " - f"gguf_iq2_xs, gguf_iq1_s, q2_k, q4_k, q6_k, mixed_fp4 or mixed_fp8.") - qtype = ggml_tensor_qtype[q_k] + f"gguf_iq2_xs, gguf_iq1_s, q2_k, q4_k, q5_k, q6_k, " + f"gguf_q4k_s, gguf_q4k_m, mixed_fp4 or mixed_fp8.") + if q_k in ggml_tensor_qtype: + qtype = ggml_tensor_qtype[q_k] + else: + qtype = gguf_mixed_qtype[q_k] # In case it needs a second try, # `from_pretrained`` may pop items out in dict @@ -550,17 +557,24 @@ class _BaseAutoModelClass: " with load_in_4bit or load_in_low_bit to get a low-bit model , and " " serialize the model using save_low_bit first.") - invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype, + invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype or + bigdl_transformers_low_bit in gguf_mixed_qtype, f"Unknown bigdl_transformers_low_bit value: {bigdl_transformers_low_bit}," f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.") # set default optimize_model=True optimize_model = kwargs.pop("optimize_model", True) - qtype = ggml_tensor_qtype[bigdl_transformers_low_bit] + if bigdl_transformers_low_bit in ggml_tensor_qtype: + qtype = ggml_tensor_qtype[bigdl_transformers_low_bit] + else: + qtype = gguf_mixed_qtype[bigdl_transformers_low_bit] if bigdl_transformers_low_bit in ["gguf_iq2_xxs", "gguf_iq2_xs", "gguf_iq1_s", "q2_k"] and \ not cpu_embedding: embedding_qtype = "q2_k" + elif bigdl_transformers_low_bit in ["gguf_q4k_s", "gguf_q4k_m"] and \ + not cpu_embedding: + embedding_qtype = "q4_k" if embedding_qtype is not None: embedding_qtype = ggml_tensor_qtype[embedding_qtype] diff --git a/python/llm/src/ipex_llm/transformers/utils.py b/python/llm/src/ipex_llm/transformers/utils.py index 49894a2b..74e10244 100644 --- a/python/llm/src/ipex_llm/transformers/utils.py +++ b/python/llm/src/ipex_llm/transformers/utils.py @@ -41,7 +41,7 @@ # SOFTWARE. import os from transformers.modeling_utils import _add_variant -from ipex_llm.ggml.quantize import ggml_tensor_qtype +from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype from ..utils.common import invalidInputError from typing import Union, Optional import torch @@ -271,10 +271,12 @@ def module_name_process(full_module_name): def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None): cur_qtype = qtype + cur_imatrix = None if model_config is not None: model_type = getattr(model_config, "model_type", None) else: model_dtype = None + if qtype in [ggml_tensor_qtype["gguf_iq2_xxs"], ggml_tensor_qtype["gguf_iq2_xs"], ggml_tensor_qtype["gguf_iq1_s"]]: # For quantization which needs importance matrix @@ -306,7 +308,6 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_confi cur_imatrix = None if new_module_name == 'lm_head': cur_qtype = ggml_tensor_qtype['sym_int8'] - return cur_qtype, cur_imatrix elif qtype == ggml_tensor_qtype["q2_k"]: new_module_name, layer, cur_module = module_name_process(full_module_name) if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]): @@ -319,8 +320,26 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_confi cur_imatrix = None if new_module_name == 'lm_head': cur_qtype = ggml_tensor_qtype['sym_int8'] + elif qtype > 100: + # gguf mixed precision + new_module_name, layer, cur_module = module_name_process(full_module_name) + num_hidden_layers = getattr(model_config, "num_hidden_layers", None) + if qtype in [gguf_mixed_qtype["gguf_q4k_s"], gguf_mixed_qtype["gguf_q4k_m"]] and \ + new_module_name == 'lm_head': + cur_qtype = ggml_tensor_qtype['q6_k'] + elif qtype == gguf_mixed_qtype["gguf_q4k_m"]: + if int(layer) < int(num_hidden_layers/2) and cur_module in ['v', 'down']: + cur_qtype = ggml_tensor_qtype['q6_k'] + else: + cur_qtype = ggml_tensor_qtype['q4_k'] + elif qtype == gguf_mixed_qtype["gguf_q4k_s"]: + if int(layer) < int(num_hidden_layers/8) and cur_module in ['v', 'down']: + cur_qtype = ggml_tensor_qtype['q5_k'] + else: + cur_qtype = ggml_tensor_qtype['q4_k'] else: - return qtype, None + pass + return cur_qtype, cur_imatrix def get_modelscope_hf_config(model_id_or_path: str,