add fallback for unsupported k-quants (#11691)
* add fallback
* fix style
* fix
parent 5079ed9e06
commit 54bf3a23a6
2 changed files with 25 additions and 1 deletion
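The k-quant formats touched here (q4_k, q5_k, q6_k, fp6_k) group weights into 256-value blocks, so they can only be applied to layers whose hidden size is a multiple of 256; for layers where that is not the case, this commit falls back to a format without that alignment requirement. A minimal sketch of the fallback mapping, written out for reference (the dictionary itself is not part of the commit; key and value names follow ggml_tensor_qtype as used in the diff below):

# Illustrative summary of the fallback mapping introduced by this commit.
# Keys and values are ggml_tensor_qtype names taken from the diff below.
K_QUANT_FALLBACK = {
    "q4_k": "asym_int4",
    "q5_k": "asym_int5",
    "q6_k": "sym_int8",
    "fp6_k": "fp6",
}
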
@@ -44,7 +44,7 @@ import warnings
 import transformers
 import importlib.util
 from ipex_llm.ggml.quantize import ggml_tensor_qtype, gguf_mixed_qtype
-from .utils import logger, get_cur_qtype_and_imatrix
+from .utils import logger, get_cur_qtype_and_imatrix, check_hidden_size
 import numpy as np
 import os
 from ipex_llm.utils.common import invalidInputError

@@ -396,6 +396,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                      ggml_tensor_qtype["asym_int4"]]:
                     cur_qtype = ggml_tensor_qtype["sym_int8"]

+                # check whether the hidden size is a multiple of 256
+                cur_qtype = check_hidden_size(cur_qtype, in_features)
+
                 new_linear = LowBitLinear(
                     in_features,
                     out_features,
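
The three added lines adjust the layer's qtype just before the LowBitLinear replacement is constructed. A minimal sketch of what that adjustment does for a layer whose input dimension is not a multiple of 256 (the in_features value is hypothetical, and the module path for check_hidden_size is assumed from the relative .utils import in the first hunk):

from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.transformers.utils import check_hidden_size  # assumed module path

cur_qtype = ggml_tensor_qtype["q5_k"]   # k-quant selected for this layer
in_features = 4544                      # hypothetical hidden size; 4544 % 256 == 192
cur_qtype = check_hidden_size(cur_qtype, in_features)
# cur_qtype is now ggml_tensor_qtype["asym_int5"], so the LowBitLinear built
# afterwards uses a format that does not require a 256-aligned hidden size.
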
@@ -361,3 +361,24 @@ def get_modelscope_hf_config(model_id_or_path: str,
 def is_torch_bf16_gpu_available():
     # always true for XPU and CPU
     return True
+
+
+def check_hidden_size(qtype, hidden_size):
+    if hidden_size % 256 != 0:
+        if qtype == ggml_tensor_qtype["q4_k"]:
+            logger.info(f"hidden size {hidden_size} is not divisible by 256, "
+                        "required for q4_k - using fallback quantization asym_int4.")
+            return ggml_tensor_qtype["asym_int4"]
+        elif qtype == ggml_tensor_qtype["q5_k"]:
+            logger.info(f"hidden size {hidden_size} is not divisible by 256, "
+                        "required for q5_k - using fallback quantization asym_int5.")
+            return ggml_tensor_qtype["asym_int5"]
+        elif qtype == ggml_tensor_qtype["q6_k"]:
+            logger.info(f"hidden size {hidden_size} is not divisible by 256, "
+                        "required for q6_k - using fallback quantization sym_int8.")
+            return ggml_tensor_qtype["sym_int8"]
+        elif qtype == ggml_tensor_qtype["fp6_k"]:
+            logger.info(f"hidden size {hidden_size} is not divisible by 256, "
+                        "required for fp6_k - using fallback quantization fp6.")
+            return ggml_tensor_qtype["fp6"]
+    return qtype
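
A quick behavioural check of the helper above, using only qtype names that appear in this diff (a sketch; it assumes check_hidden_size is importable from the transformers utils module shown in this hunk):

from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.transformers.utils import check_hidden_size  # assumed module path

# hidden size divisible by 256: the k-quant is kept (4096 == 16 * 256)
assert check_hidden_size(ggml_tensor_qtype["q4_k"], 4096) == ggml_tensor_qtype["q4_k"]

# hidden size not divisible by 256: q6_k falls back to sym_int8
assert check_hidden_size(ggml_tensor_qtype["q6_k"], 4544) == ggml_tensor_qtype["sym_int8"]

# non-k qtypes pass through unchanged, whatever the hidden size
assert check_hidden_size(ggml_tensor_qtype["asym_int4"], 4544) == ggml_tensor_qtype["asym_int4"]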