LLM: add mixed precision for lm_head (#10795)
* add mixed_quantization
* meet code review
* update
* fix style
* meet review
parent 8796401b08
commit 439c834ed3
2 changed files with 25 additions and 4 deletions
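For context, a minimal usage sketch of the new flag (this assumes the ipex_llm transformers-style API; the checkpoint name is a placeholder, not part of this commit):

# Minimal usage sketch; the checkpoint name below is a placeholder.
from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",      # placeholder checkpoint
    load_in_low_bit="sym_int4",           # most weights stay in sym_int4
    mixed_precision=True,                 # lm_head is promoted to sym_int8
    trust_remote_code=True,
)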
@@ -92,6 +92,13 @@ if is_auto_awq_available():
     from transformers.utils.quantization_config import AwqBackendPackingMethod
 
 
+def is_lm_head(name, model_config, out_features):
+    if name == "lm_head" or getattr(model_config, "vocab_size", None) == out_features:
+        return True
+    else:
+        return False
+
+
 def is_linear_module(module):
 
     in_features = None
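A quick illustration of the helper's two match paths, using a stand-in config object (not part of the commit):

# Illustration only: SimpleNamespace stands in for a real model config.
from types import SimpleNamespace

config = SimpleNamespace(vocab_size=32000)

print(is_lm_head("lm_head", config, out_features=4096))   # True: matched by module name
print(is_lm_head("output", config, out_features=32000))   # True: out_features equals vocab_size
print(is_lm_head("q_proj", config, out_features=4096))    # False: neither condition holds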
@@ -220,7 +227,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  cpu_embedding=False, prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_config=None, torch_dtype=torch.float32,
-                                 enable_xetla=False):
+                                 enable_xetla=False,
+                                 mixed_precision=False):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@@ -237,7 +245,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         if is_linear and not isinstance(module, LowBitLinear):
             in_features, out_features, mp_group = linear_args
             optimize_lm_head = False
-            if name == "lm_head":
+            if is_lm_head(name, model_config, out_features):
                 model_type = getattr(model_config, "model_type", None)
                 if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
                                                                       None) == "1":
@@ -291,6 +299,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                                                    full_module_name,
                                                                    imatrix_data,
                                                                    model_config)
+                # mixed precision for lm_head
+                if mixed_precision and is_lm_head(name, model_config, out_features):
+                    if cur_qtype in [ggml_tensor_qtype["sym_int4"],
+                                     ggml_tensor_qtype["asym_int4"]]:
+                        cur_qtype = ggml_tensor_qtype["sym_int8"]
                 device = module.weight.data.device
                 # Copy the weights
                 paramsLowBit = FP4Params(data=module.weight.data,
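The promotion rule added above can be read in isolation: when mixed precision is requested and the module is the lm_head, a 4-bit qtype is bumped to sym_int8; everything else keeps its original qtype. A standalone sketch of that rule (the integer codes below are placeholders, not the real ggml_tensor_qtype values):

# Standalone sketch of the promotion rule; qtype codes are made-up placeholders.
GGML_TENSOR_QTYPE = {"sym_int4": 2, "asym_int4": 1, "sym_int8": 8}

def pick_qtype(cur_qtype, is_head, mixed_precision):
    # Only the lm_head is promoted, and only from the two int4 variants.
    if mixed_precision and is_head and cur_qtype in (GGML_TENSOR_QTYPE["sym_int4"],
                                                     GGML_TENSOR_QTYPE["asym_int4"]):
        return GGML_TENSOR_QTYPE["sym_int8"]
    return cur_qtype

assert pick_qtype(GGML_TENSOR_QTYPE["sym_int4"], is_head=True, mixed_precision=True) \
    == GGML_TENSOR_QTYPE["sym_int8"]
assert pick_qtype(GGML_TENSOR_QTYPE["sym_int4"], is_head=False, mixed_precision=True) \
    == GGML_TENSOR_QTYPE["sym_int4"]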
@@ -409,6 +422,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 model_config=model_config,
                 torch_dtype=torch_dtype,
                 enable_xetla=enable_xetla,
+                mixed_precision=mixed_precision
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -684,7 +698,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
                           lightweight_bmm=False, torch_dtype="auto",
                           imatrix_data=None,
                           embedding_qtype=None,
-                          enable_xetla=False):
+                          enable_xetla=False,
+                          mixed_precision=False):
     logger.info(f"Converting the current model to "
                 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                 f"format......")
@@ -709,6 +724,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         model_config=getattr(model, "config", None),
         torch_dtype=torch_dtype,
         enable_xetla=enable_xetla,
+        mixed_precision=mixed_precision,
     )
     if not has_been_replaced:
         warnings.warn(
@@ -140,6 +140,9 @@ class _BaseAutoModelClass:
                           specify the model hub. Default to be ``'huggingface'``.
         :param embedding_qtype: str value, options are ``'q2_k'`` now. Default to be None.
             Relevant low bit optimizations will be applied to nn.Embedding layer.
+        :param mixed_precision: boolean value, whether to use mixed precision quantization.
+            Default to be False. If set to True, we will use sym_int8 for lm_head when
+            load_in_low_bit is sym_int4 or asym_int4.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
@@ -394,6 +397,7 @@ class _BaseAutoModelClass:
         quant_config = kwargs.pop("quantization_config", None)
         imatrix_data = kwargs.pop("imatrix_data", None)
         embedding_qtype = kwargs.pop("embedding_qtype", None)
+        mixed_precision = kwargs.pop("mixed_precision", False)
         if embedding_qtype is not None:
             embedding_qtype = ggml_tensor_qtype[embedding_qtype]
         enable_xetla = kwargs.pop("enable_xetla", False)
@@ -463,7 +467,8 @@ class _BaseAutoModelClass:
                                      torch_dtype=kwargs.get("torch_dtype", 'auto'),
                                      imatrix_data=imatrix_data,
                                      embedding_qtype=embedding_qtype,
-                                     enable_xetla=enable_xetla,)
+                                     enable_xetla=enable_xetla,
+                                     mixed_precision=mixed_precision)
         model.config.update({"bigdl_transformers_low_bit": q_k})
 
         # enable tie_word_embeddings for MPT