fix npu lm_head cpu condition (#11976)

* fix

* fix

* fix

* fix stype

* fix style

* fix style
This commit is contained in:
Ruonan Wang 2024-08-30 02:11:26 -07:00 committed by GitHub
parent 60aa1a2c0f
commit 573c20bae6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 15 additions and 1 deletions

View file

@ -66,6 +66,7 @@ if __name__ == "__main__":
intra_pp=args.intra_pp, intra_pp=args.intra_pp,
inter_pp=args.inter_pp, inter_pp=args.inter_pp,
transpose_value_cache=not args.disable_transpose_value_cache, transpose_value_cache=not args.disable_transpose_value_cache,
modules_to_not_convert=['vpm', 'resampler']
) )
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

View file

@ -42,18 +42,31 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
model.apply(pre_compute_inv_freq) model.apply(pre_compute_inv_freq)
# MiniCPM-V 2.6 and minicpm-2b must put lm_head on CPU now
cpu_lm_head = (
(model.config.model_type == "minicpmv" and model.config.hidden_size == 3584 and
model.config.vocab_size == 151666)
or (
model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40
)
or os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0"
)
if model.config.model_type == "minicpmv" and hasattr(model, "llm"): if model.config.model_type == "minicpmv" and hasattr(model, "llm"):
# MiniCPM-V # MiniCPM-V
if model.config.hidden_size == 2304 and model.config.vocab_size == 122753: if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
# MiniCPM-V 2
model.llm.config.model_type = "minicpm" model.llm.config.model_type = "minicpm"
elif model.config.hidden_size == 3584 and model.config.vocab_size == 151666: elif model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
# MiniCPM-V 2.6
model.llm.config.model_type = "qwen2" model.llm.config.model_type = "qwen2"
elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256: elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
# MiniCPM-V 2.5
model.llm.config.model_type = "llama" model.llm.config.model_type = "llama"
model = model.llm model = model.llm
# lm_head to cpu optimization # lm_head to cpu optimization
if os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0": if cpu_lm_head:
# disable the optimization by default # disable the optimization by default
from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8 from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8
if qtype == "sym_int4_rtn": if qtype == "sym_int4_rtn":