From 23631cd357e0f8f78e00991a06c777598ff4176f Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Fri, 23 Aug 2024 10:39:47 +0300
Subject: [PATCH] disable lm_head opt for baichuan2-13b (#11905)

---
 python/llm/src/ipex_llm/transformers/convert.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 39e1edcb..3057c6f6 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -405,9 +405,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             optimize_lm_head = (
                 is_lm_head(name, model_config, out_features)
                 and (
-                    (not os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "0")
-                    or os.environ.get("IPEX_LLM_LOW_MEM", "0") == "1"
-                    and getattr(model_config, "model_type", "") in ["gptj", "llama", "qwen2"]
+                    not os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "0"
+                )
+                and (
+                    not (getattr(model_config, "model_type", "") == "baichuan" and
+                         model.config.hidden_size == 5120)  # except baichuan2-13B
                 )
             )
             with init_empty_weights():
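
Note (not part of the patch): a minimal standalone sketch of the new guard, showing how a
baichuan model with hidden_size 5120 (baichuan2-13B) is excluded from the lm_head
optimization while other models still honor IPEX_LLM_LAST_LM_HEAD. The helper name
should_optimize_lm_head and the use of a single model_config object for both model_type
and hidden_size are assumptions for illustration only; the real code reads hidden_size
from model.config.

import os
from types import SimpleNamespace


def should_optimize_lm_head(is_head: bool, model_config) -> bool:
    # Optimization stays on unless explicitly disabled via IPEX_LLM_LAST_LM_HEAD=0 ...
    last_lm_head_enabled = os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) != "0"
    # ... and is now always skipped for baichuan models with hidden_size 5120
    # (baichuan2-13B), mirroring the condition added by this patch.
    is_baichuan2_13b = (
        getattr(model_config, "model_type", "") == "baichuan"
        and model_config.hidden_size == 5120
    )
    return is_head and last_lm_head_enabled and not is_baichuan2_13b


if __name__ == "__main__":
    baichuan2_13b = SimpleNamespace(model_type="baichuan", hidden_size=5120)
    llama = SimpleNamespace(model_type="llama", hidden_size=4096)
    print(should_optimize_lm_head(True, baichuan2_13b))  # False
    print(should_optimize_lm_head(True, llama))          # True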