LLM: add mixed precision for lm_head (#10795)
* add mixed_quantization
* meet code review
* update
* fix style
* meet review
parent 8796401b08
commit 439c834ed3
2 changed files with 25 additions and 4 deletions
@@ -92,6 +92,13 @@ if is_auto_awq_available():
     from transformers.utils.quantization_config import AwqBackendPackingMethod


+def is_lm_head(name, model_config, out_features):
+    if name == "lm_head" or getattr(model_config, "vocab_size", None) == out_features:
+        return True
+    else:
+        return False
+
+
 def is_linear_module(module):

     in_features = None
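For context, the new is_lm_head helper treats a Linear layer as the language-model head either when it is literally named "lm_head" or when its output dimension equals the model's vocabulary size (which also catches heads with non-standard names). A minimal sketch of both detection paths, using a made-up config object; the import path is an assumption, not part of this commit:

from types import SimpleNamespace
from ipex_llm.transformers.convert import is_lm_head  # assumed module path

# Hypothetical config with a 32000-token vocabulary (illustrative only).
config = SimpleNamespace(vocab_size=32000, model_type="llama")

assert is_lm_head("lm_head", config, out_features=4096)      # matched by name
assert is_lm_head("output", config, out_features=32000)      # matched by vocab size
assert not is_lm_head("q_proj", config, out_features=4096)   # ordinary projection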
@@ -220,7 +227,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  cpu_embedding=False, prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_config=None, torch_dtype=torch.float32,
-                                 enable_xetla=False):
+                                 enable_xetla=False,
+                                 mixed_precision=False):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@@ -237,7 +245,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         if is_linear and not isinstance(module, LowBitLinear):
             in_features, out_features, mp_group = linear_args
             optimize_lm_head = False
-            if name == "lm_head":
+            if is_lm_head(name, model_config, out_features):
                 model_type = getattr(model_config, "model_type", None)
                 if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
                                                                       None) == "1":
@@ -291,6 +299,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                                                    full_module_name,
                                                                    imatrix_data,
                                                                    model_config)
+                # mixed precision for lm_head
+                if mixed_precision and is_lm_head(name, model_config, out_features):
+                    if cur_qtype in [ggml_tensor_qtype["sym_int4"],
+                                     ggml_tensor_qtype["asym_int4"]]:
+                        cur_qtype = ggml_tensor_qtype["sym_int8"]
                 device = module.weight.data.device
                 # Copy the weights
                 paramsLowBit = FP4Params(data=module.weight.data,
@@ -409,6 +422,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 model_config=model_config,
                 torch_dtype=torch_dtype,
                 enable_xetla=enable_xetla,
+                mixed_precision=mixed_precision
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -684,7 +698,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          lightweight_bmm=False, torch_dtype="auto",
                          imatrix_data=None,
                          embedding_qtype=None,
-                         enable_xetla=False):
+                         enable_xetla=False,
+                         mixed_precision=False):
     logger.info(f"Converting the current model to "
                 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                 f"format......")
@@ -709,6 +724,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         model_config=getattr(model, "config", None),
         torch_dtype=torch_dtype,
         enable_xetla=enable_xetla,
+        mixed_precision=mixed_precision,
     )
     if not has_been_replaced:
         warnings.warn(
@@ -140,6 +140,9 @@ class _BaseAutoModelClass:
                              specify the model hub. Default to be ``'huggingface'``.
         :param embedding_qtype: str value, options are ``'q2_k'`` now. Default to be None.
                                 Relevant low bit optimizations will be applied to nn.Embedding layer.
+        :param mixed_precision: boolean value, whether to use mixed precision quantization.
+                                Default to be False. If set to True, we will use sym_int8 for lm_head when
+                                load_in_low_bit is sym_int4 or asym_int4.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
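A hedged usage sketch for the new flag, following the docstring above; the checkpoint id is illustrative and not part of this commit:

from ipex_llm.transformers import AutoModelForCausalLM

# Weights are quantized to sym_int4, except the lm_head, which falls back
# to sym_int8 because mixed_precision=True.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # illustrative checkpoint
    load_in_low_bit="sym_int4",
    mixed_precision=True,
    trust_remote_code=True,
)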
@@ -394,6 +397,7 @@ class _BaseAutoModelClass:
         quant_config = kwargs.pop("quantization_config", None)
         imatrix_data = kwargs.pop("imatrix_data", None)
         embedding_qtype = kwargs.pop("embedding_qtype", None)
+        mixed_precision = kwargs.pop("mixed_precision", False)
         if embedding_qtype is not None:
             embedding_qtype = ggml_tensor_qtype[embedding_qtype]
         enable_xetla = kwargs.pop("enable_xetla", False)
@@ -463,7 +467,8 @@ class _BaseAutoModelClass:
                 torch_dtype=kwargs.get("torch_dtype", 'auto'),
                 imatrix_data=imatrix_data,
                 embedding_qtype=embedding_qtype,
-                enable_xetla=enable_xetla,)
+                enable_xetla=enable_xetla,
+                mixed_precision=mixed_precision)
         model.config.update({"bigdl_transformers_low_bit": q_k})

         # enable tie_word_embeddings for MPT
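As a quick sanity check after loading, the format recorded by the config update above can be read back (a minimal sketch, reusing the model from the earlier example):

# from_pretrained stamps the chosen low-bit format onto the HF config.
print(getattr(model.config, "bigdl_transformers_low_bit", None))  # e.g. "sym_int4"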