diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 39e1edcb..3057c6f6 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -405,9 +405,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, optimize_lm_head = ( is_lm_head(name, model_config, out_features) and ( - (not os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "0") - or os.environ.get("IPEX_LLM_LOW_MEM", "0") == "1" - and getattr(model_config, "model_type", "") in ["gptj", "llama", "qwen2"] + not os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "0" + ) + and ( + not (getattr(model_config, "model_type", "") == "baichuan" and + model.config.hidden_size == 5120) # except baichuan2-13B ) ) with init_empty_weights():