Force lm_head optimization in any model when the environment variable is set (#11830)
This commit is contained in:
parent
3b630fb9df
commit
e966e85df8
1 changed file with 8 additions and 8 deletions
|
|
@ -400,14 +400,14 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||||
|
|
||||||
if is_linear and not isinstance(module, LowBitLinear):
|
if is_linear and not isinstance(module, LowBitLinear):
|
||||||
in_features, out_features, mp_group = linear_args
|
in_features, out_features, mp_group = linear_args
|
||||||
optimize_lm_head = False
|
optimize_lm_head = (
|
||||||
if is_lm_head(name, model_config, out_features):
|
is_lm_head(name, model_config, out_features)
|
||||||
model_type = getattr(model_config, "model_type", None)
|
and (
|
||||||
if model_type in ["gptj", "llama", "qwen2"]:
|
os.environ.get("IPEX_LLM_LAST_LM_HEAD", "0") == "1"
|
||||||
if os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) is not None:
|
or os.environ.get("IPEX_LLM_LOW_MEM", "0") == "1"
|
||||||
optimize_lm_head = os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "1"
|
and getattr(model_config, "model_type", "") in ["gptj", "llama", "qwen2"]
|
||||||
elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
|
)
|
||||||
optimize_lm_head = os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
|
)
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
new_linear = None
|
new_linear = None
|
||||||
is_gptq = is_gptq_linear(module)
|
is_gptq = is_gptq_linear(module)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue