force lm_head optimization in any model if environment variable is set (#11830)

Yishuo Wang 2024-08-16 16:48:45 +08:00 committed by GitHub
parent 3b630fb9df
commit e966e85df8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194


@@ -400,14 +400,14 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         if is_linear and not isinstance(module, LowBitLinear):
             in_features, out_features, mp_group = linear_args
-            optimize_lm_head = False
-            if is_lm_head(name, model_config, out_features):
-                model_type = getattr(model_config, "model_type", None)
-                if model_type in ["gptj", "llama", "qwen2"]:
-                    if os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) is not None:
-                        optimize_lm_head = os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "1"
-                    elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
-                        optimize_lm_head = os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
+            optimize_lm_head = (
+                is_lm_head(name, model_config, out_features)
+                and (
+                    os.environ.get("IPEX_LLM_LAST_LM_HEAD", "0") == "1"
+                    or os.environ.get("IPEX_LLM_LOW_MEM", "0") == "1"
+                    and getattr(model_config, "model_type", "") in ["gptj", "llama", "qwen2"]
+                )
+            )
             with init_empty_weights():
                 new_linear = None
                 is_gptq = is_gptq_linear(module)
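
The net effect of the rewrite: IPEX_LLM_LAST_LM_HEAD=1 now forces the lm_head optimization regardless of architecture, while IPEX_LLM_LOW_MEM=1 keeps the old model-type gate, because Python's `and` binds tighter than `or`. Below is a minimal standalone sketch of that precedence; the helper name and its arguments are hypothetical illustrations, not part of the ipex-llm API, where the same boolean is computed inline in _replace_with_low_bit_linear.

import os

# Hypothetical standalone mirror of the new gating expression.
def should_optimize_lm_head(is_head: bool, model_type: str) -> bool:
    # `and` binds tighter than `or`, so the model_type check only
    # constrains the IPEX_LLM_LOW_MEM branch.
    return (
        is_head
        and (
            os.environ.get("IPEX_LLM_LAST_LM_HEAD", "0") == "1"
            or os.environ.get("IPEX_LLM_LOW_MEM", "0") == "1"
            and model_type in ["gptj", "llama", "qwen2"]
        )
    )

os.environ["IPEX_LLM_LAST_LM_HEAD"] = "1"
assert should_optimize_lm_head(True, "mistral")      # forced for any model type
del os.environ["IPEX_LLM_LAST_LM_HEAD"]
os.environ["IPEX_LLM_LOW_MEM"] = "1"
assert not should_optimize_lm_head(True, "mistral")  # low-mem path stays gated
assert should_optimize_lm_head(True, "llama")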