LLM: add mixed precision for lm_head (#10795)

* add mixed_quantization

* meet code review

* update

* fix style

* meet review
This commit is contained in:
Ruonan Wang 2024-04-18 19:11:31 +08:00 committed by GitHub
parent 8796401b08
commit 439c834ed3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 4 deletions

View file

@ -92,6 +92,13 @@ if is_auto_awq_available():
from transformers.utils.quantization_config import AwqBackendPackingMethod
def is_lm_head(name, model_config, out_features):
if name == "lm_head" or getattr(model_config, "vocab_size", None) == out_features:
return True
else:
return False
def is_linear_module(module):
in_features = None
@ -220,7 +227,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
cpu_embedding=False, prefix_name='',
imatrix_data=None, embedding_qtype=None,
model_config=None, torch_dtype=torch.float32,
enable_xetla=False):
enable_xetla=False,
mixed_precision=False):
from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
FP16Linear, BF16Linear
from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@ -237,7 +245,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
if is_linear and not isinstance(module, LowBitLinear):
in_features, out_features, mp_group = linear_args
optimize_lm_head = False
if name == "lm_head":
if is_lm_head(name, model_config, out_features):
model_type = getattr(model_config, "model_type", None)
if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
None) == "1":
@ -291,6 +299,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
full_module_name,
imatrix_data,
model_config)
# mixed precison for lm_head
if mixed_precision and is_lm_head(name, model_config, out_features):
if cur_qtype in [ggml_tensor_qtype["sym_int4"],
ggml_tensor_qtype["asym_int4"]]:
cur_qtype = ggml_tensor_qtype["sym_int8"]
device = module.weight.data.device
# Copy the weights
paramsLowBit = FP4Params(data=module.weight.data,
@ -409,6 +422,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
model_config=model_config,
torch_dtype=torch_dtype,
enable_xetla=enable_xetla,
mixed_precision=mixed_precision
)
has_been_replaced = _flag or has_been_replaced
return model, has_been_replaced
@ -684,7 +698,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
lightweight_bmm=False, torch_dtype="auto",
imatrix_data=None,
embedding_qtype=None,
enable_xetla=False):
enable_xetla=False,
mixed_precision=False):
logger.info(f"Converting the current model to "
f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
f"format......")
@ -709,6 +724,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
model_config=getattr(model, "config", None),
torch_dtype=torch_dtype,
enable_xetla=enable_xetla,
mixed_precision=mixed_precision,
)
if not has_been_replaced:
warnings.warn(

View file

@ -140,6 +140,9 @@ class _BaseAutoModelClass:
specify the model hub. Default to be ``'huggingface'``.
:param embedding_qtype: str value, options are ``'q2_k'`` now. Default to be None.
Relevant low bit optimizations will be applied to nn.Embedding layer.
:param mixed_precision: boolean value, Whether to use mixed precision quantization.
Default to be False. If set to True, we will use sym_int8 for lm_head when
load_in_low_bit is sym_int4 or asym_int4.
:return: a model instance
"""
pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
@ -394,6 +397,7 @@ class _BaseAutoModelClass:
quant_config = kwargs.pop("quantization_config", None)
imatrix_data = kwargs.pop("imatrix_data", None)
embedding_qtype = kwargs.pop("embedding_qtype", None)
mixed_precision = kwargs.pop("mixed_precision", False)
if embedding_qtype is not None:
embedding_qtype = ggml_tensor_qtype[embedding_qtype]
enable_xetla = kwargs.pop("enable_xetla", False)
@ -463,7 +467,8 @@ class _BaseAutoModelClass:
torch_dtype=kwargs.get("torch_dtype", 'auto'),
imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype,
enable_xetla=enable_xetla,)
enable_xetla=enable_xetla,
mixed_precision=mixed_precision)
model.config.update({"bigdl_transformers_low_bit": q_k})
# enable tie_word_embeddings for MPT