add fallback for unsupported k-quants (#11691)

* add fallback

* fix style

* fix
commit 54bf3a23a6 (parent 5079ed9e06)
Author: Ruonan Wang, committed by GitHub
Date: 2024-07-31 06:39:58 +03:00
2 changed files with 25 additions and 1 deletion


@@ -44,7 +44,7 @@ import warnings
 import transformers
 import importlib.util
 from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
-from .utils import logger, get_cur_qtype_and_imatrix
+from .utils import logger, get_cur_qtype_and_imatrix, check_hidden_size
 import numpy as np
 import os
 from ipex_llm.utils.common import invalidInputError
@@ -396,6 +396,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  ggml_tensor_qtype["asym_int4"]]:
                     cur_qtype = ggml_tensor_qtype["sym_int8"]
+                # check whether hidden size is a multiple of 256
+                cur_qtype = check_hidden_size(cur_qtype, in_features)
                 new_linear = LowBitLinear(
                     in_features,
                     out_features,
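For context: the k-quants this patch guards against (q4_k, q5_k, q6_k, and ipex-llm's fp6_k) pack weights into 256-element super-blocks (QK_K = 256 in llama.cpp-style k-quants), so a layer's in_features must be a multiple of 256 or the last block would be partial. A minimal standalone sketch of that constraint; the 4096/2000 sizes are illustrative, not from the patch:

QK_K = 256  # k-quant super-block size

for hidden_size in (4096, 2000):
    blocks, rem = divmod(hidden_size, QK_K)
    status = "k-quant ok" if rem == 0 else "needs fallback"
    print(f"{hidden_size}: {blocks} full blocks, {rem} leftover -> {status}")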


@@ -361,3 +361,24 @@ def get_modelscope_hf_config(model_id_or_path: str,
 def is_torch_bf16_gpu_available():
     # always true for XPU and CPU
     return True
+
+
+def check_hidden_size(qtype, hidden_size):
+    if hidden_size % 256 != 0:
+        if qtype == ggml_tensor_qtype["q4_k"]:
+            logger.info(f"hidden size {hidden_size} is not divisible by 256, "
+                        "required for q4_k - using fallback quantization asym_int4.")
+            return ggml_tensor_qtype["asym_int4"]
+        elif qtype == ggml_tensor_qtype["q5_k"]:
+            logger.info(f"hidden size {hidden_size} is not divisible by 256, "
+                        "required for q5_k - using fallback quantization asym_int5.")
+            return ggml_tensor_qtype["asym_int5"]
+        elif qtype == ggml_tensor_qtype["q6_k"]:
+            logger.info(f"hidden size {hidden_size} is not divisible by 256, "
+                        "required for q6_k - using fallback quantization sym_int8.")
+            return ggml_tensor_qtype["sym_int8"]
+        elif qtype == ggml_tensor_qtype["fp6_k"]:
+            logger.info(f"hidden size {hidden_size} is not divisible by 256, "
+                        "required for fp6_k - using fallback quantization fp6.")
+            return ggml_tensor_qtype["fp6"]
+    return qtype
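A self-contained sketch of how the new helper behaves, with ggml_tensor_qtype stubbed out (the real codes live in ipex_llm.ggml.quantize) and the four branches folded into a fallback table; this mirrors the mapping in the diff above, it is not the committed implementation:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("check_hidden_size_demo")

# Stub codes for illustration only; real values come from
# ipex_llm.ggml.quantize.ggml_tensor_qtype.
ggml_tensor_qtype = {"q4_k": 100, "q5_k": 101, "q6_k": 102, "fp6_k": 103,
                     "asym_int4": 200, "asym_int5": 201,
                     "sym_int8": 202, "fp6": 203}

# k-quant -> fallback, mirroring the four branches in the diff above
FALLBACKS = {"q4_k": "asym_int4", "q5_k": "asym_int5",
             "q6_k": "sym_int8", "fp6_k": "fp6"}


def check_hidden_size(qtype, hidden_size):
    if hidden_size % 256 != 0:
        for kquant, fallback in FALLBACKS.items():
            if qtype == ggml_tensor_qtype[kquant]:
                logger.info(f"hidden size {hidden_size} is not divisible by 256, "
                            f"required for {kquant} - falling back to {fallback}.")
                return ggml_tensor_qtype[fallback]
    return qtype


# 4096 % 256 == 0, so q4_k is kept; 2000 % 256 != 0, so it falls back.
assert check_hidden_size(ggml_tensor_qtype["q4_k"], 4096) == ggml_tensor_qtype["q4_k"]
assert check_hidden_size(ggml_tensor_qtype["q4_k"], 2000) == ggml_tensor_qtype["asym_int4"]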