diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 8594a4b4..376c1793 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -92,6 +92,13 @@ if is_auto_awq_available():
     from transformers.utils.quantization_config import AwqBackendPackingMethod
 
 
+def is_lm_head(name, model_config, out_features):
+    if name == "lm_head" or getattr(model_config, "vocab_size", None) == out_features:
+        return True
+    else:
+        return False
+
+
 def is_linear_module(module):
 
     in_features = None
@@ -220,7 +227,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                 cpu_embedding=False, prefix_name='',
                                 imatrix_data=None, embedding_qtype=None,
                                 model_config=None, torch_dtype=torch.float32,
-                                enable_xetla=False):
+                                enable_xetla=False,
+                                mixed_precision=False):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@@ -237,7 +245,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         if is_linear and not isinstance(module, LowBitLinear):
             in_features, out_features, mp_group = linear_args
             optimize_lm_head = False
-            if name == "lm_head":
+            if is_lm_head(name, model_config, out_features):
                 model_type = getattr(model_config, "model_type", None)
                 if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
                                                                       None) == "1":
@@ -291,6 +299,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                                                        full_module_name,
                                                                        imatrix_data,
                                                                        model_config)
+                    # mixed precision for lm_head
+                    if mixed_precision and is_lm_head(name, model_config, out_features):
+                        if cur_qtype in [ggml_tensor_qtype["sym_int4"],
+                                         ggml_tensor_qtype["asym_int4"]]:
+                            cur_qtype = ggml_tensor_qtype["sym_int8"]
                     device = module.weight.data.device
                     # Copy the weights
                     paramsLowBit = FP4Params(data=module.weight.data,
@@ -409,6 +422,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 model_config=model_config,
                 torch_dtype=torch_dtype,
                 enable_xetla=enable_xetla,
+                mixed_precision=mixed_precision
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -684,7 +698,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
                         lightweight_bmm=False, torch_dtype="auto",
                         imatrix_data=None,
                         embedding_qtype=None,
-                        enable_xetla=False):
+                        enable_xetla=False,
+                        mixed_precision=False):
     logger.info(f"Converting the current model to "
                 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                 f"format......")
@@ -709,6 +724,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         model_config=getattr(model, "config", None),
         torch_dtype=torch_dtype,
         enable_xetla=enable_xetla,
+        mixed_precision=mixed_precision,
    )
    if not has_been_replaced:
        warnings.warn(
diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index 34b4b0c0..b1604b6c 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -140,6 +140,9 @@ class _BaseAutoModelClass:
                           specify the model hub. Default to be ``'huggingface'``.
         :param embedding_qtype: str value, options are ``'q2_k'`` now. Default to be None.
                                 Relevant low bit optimizations will be applied to nn.Embedding layer.
+        :param mixed_precision: boolean value, Whether to use mixed precision quantization.
+                                Default to be False. If set to True, we will use sym_int8 for lm_head when
+                                load_in_low_bit is sym_int4 or asym_int4.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
@@ -394,6 +397,7 @@ class _BaseAutoModelClass:
         quant_config = kwargs.pop("quantization_config", None)
         imatrix_data = kwargs.pop("imatrix_data", None)
         embedding_qtype = kwargs.pop("embedding_qtype", None)
+        mixed_precision = kwargs.pop("mixed_precision", False)
         if embedding_qtype is not None:
             embedding_qtype = ggml_tensor_qtype[embedding_qtype]
         enable_xetla = kwargs.pop("enable_xetla", False)
@@ -463,7 +467,8 @@ class _BaseAutoModelClass:
                                     torch_dtype=kwargs.get("torch_dtype", 'auto'),
                                     imatrix_data=imatrix_data,
                                     embedding_qtype=embedding_qtype,
-                                    enable_xetla=enable_xetla,)
+                                    enable_xetla=enable_xetla,
+                                    mixed_precision=mixed_precision)
         model.config.update({"bigdl_transformers_low_bit": q_k})
 
         # enable tie_word_embeddings for MPT
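
For reference, a minimal usage sketch of the new kwarg, assuming the public AutoModelForCausalLM wrapper exported by ipex_llm.transformers (which forwards keyword arguments to the _BaseAutoModelClass.from_pretrained shown above); the checkpoint name is illustrative only:

    # Hypothetical example: quantize the model to sym_int4 while keeping lm_head at
    # sym_int8 via the mixed_precision flag added in this change.
    from ipex_llm.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",   # illustrative model id
        load_in_low_bit="sym_int4",        # body weights quantized to sym_int4
        mixed_precision=True,              # lm_head falls back to sym_int8
    )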