disable lm_head opt for baichuan2-13b (#11905)

This commit is contained in:
Yina Chen 2024-08-23 10:39:47 +03:00 committed by GitHub
parent 4cf640c548
commit 23631cd357
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -405,9 +405,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
optimize_lm_head = ( optimize_lm_head = (
is_lm_head(name, model_config, out_features) is_lm_head(name, model_config, out_features)
and ( and (
(not os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "0") not os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "0"
or os.environ.get("IPEX_LLM_LOW_MEM", "0") == "1" )
and getattr(model_config, "model_type", "") in ["gptj", "llama", "qwen2"] and (
not (getattr(model_config, "model_type", "") == "baichuan" and
model.config.hidden_size == 5120) # except baichuan2-13B
) )
) )
with init_empty_weights(): with init_empty_weights():