diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 2ce495f0..65944f05 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -400,14 +400,14 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         if is_linear and not isinstance(module, LowBitLinear):
             in_features, out_features, mp_group = linear_args
-            optimize_lm_head = False
-            if is_lm_head(name, model_config, out_features):
-                model_type = getattr(model_config, "model_type", None)
-                if model_type in ["gptj", "llama", "qwen2"]:
-                    if os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) is not None:
-                        optimize_lm_head = os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "1"
-                    elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
-                        optimize_lm_head = os.environ.get("IPEX_LLM_LOW_MEM", None) == "1"
+            optimize_lm_head = (
+                is_lm_head(name, model_config, out_features)
+                and (
+                    os.environ.get("IPEX_LLM_LAST_LM_HEAD", "0") == "1"
+                    or os.environ.get("IPEX_LLM_LOW_MEM", "0") == "1"
+                    and getattr(model_config, "model_type", "") in ["gptj", "llama", "qwen2"]
+                )
+            )
         with init_empty_weights():
             new_linear = None
             is_gptq = is_gptq_linear(module)
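
Note (not part of the patch): a minimal sketch of how the refactored condition evaluates. Because Python's `and` binds tighter than `or`, the added expression parses as `is_lm_head(...) and (LAST_LM_HEAD == "1" or (LOW_MEM == "1" and model_type in [...]))`, i.e. the model-type check now gates only the `IPEX_LLM_LOW_MEM` branch, whereas the removed code required the model type to match before consulting either environment variable. The placeholder variables below stand in for `is_lm_head(...)` and `model_config.model_type`.

    import os

    # Placeholders for is_lm_head(name, model_config, out_features) and model_config.model_type.
    head_matches = True
    model_type = "llama"

    # Fully parenthesized form of the new expression: A or B and C  ==  A or (B and C)
    optimize_lm_head = (
        head_matches
        and (
            os.environ.get("IPEX_LLM_LAST_LM_HEAD", "0") == "1"
            or (
                os.environ.get("IPEX_LLM_LOW_MEM", "0") == "1"
                and model_type in ["gptj", "llama", "qwen2"]
            )
        )
    )
    print(optimize_lm_head)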