fix npu lm_head cpu condition (#11976)
* fix
* fix
* fix
* fix style
* fix style
* fix style
parent 60aa1a2c0f
commit 573c20bae6
2 changed files with 15 additions and 1 deletion
@@ -66,6 +66,7 @@ if __name__ == "__main__":
         intra_pp=args.intra_pp,
         inter_pp=args.inter_pp,
         transpose_value_cache=not args.disable_transpose_value_cache,
+        modules_to_not_convert=['vpm', 'resampler']
     )
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
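For context, the new modules_to_not_convert=['vpm', 'resampler'] argument keeps MiniCPM-V's vision tower (vpm) and resampler out of low-bit conversion, so only the language model is quantized. A minimal sketch of the surrounding call, assuming the AutoModel loader from ipex_llm.transformers.npu_model that the NPU examples use; every argument value other than modules_to_not_convert is a placeholder, not part of this diff:

# Sketch only: assumes the NPU AutoModel loader used by the MiniCPM-V example;
# the model path and numeric arguments are placeholder values.
from ipex_llm.transformers.npu_model import AutoModel

model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-V-2_6",              # hypothetical model path
    trust_remote_code=True,
    load_in_low_bit="sym_int4",           # assumed low-bit type
    intra_pp=2,
    inter_pp=2,
    transpose_value_cache=True,
    modules_to_not_convert=['vpm', 'resampler'],  # skip vision modules during quantization
)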
@@ -42,18 +42,31 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
         from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
 
+    # MiniCPM-V 2.6 and minicpm-2b must put lm_head on CPU now
+    cpu_lm_head = (
+        (model.config.model_type == "minicpmv" and model.config.hidden_size == 3584 and
+            model.config.vocab_size == 151666)
+        or (
+            model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40
+        )
+        or os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0"
+    )
+
     if model.config.model_type == "minicpmv" and hasattr(model, "llm"):
         # MiniCPM-V
         if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
+            # MiniCPM-V 2
             model.llm.config.model_type = "minicpm"
         elif model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
+            # MiniCPM-V 2.6
             model.llm.config.model_type = "qwen2"
         elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
+            # MiniCPM-V 2.5
             model.llm.config.model_type = "llama"
         model = model.llm
 
     # lm_head to cpu optimization
-    if os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0":
+    if cpu_lm_head:
         # disable the optimization by default
         from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8
         if qtype == "sym_int4_rtn":
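For reference, a standalone sketch of how the new cpu_lm_head gate behaves. The helper name needs_cpu_lm_head and the dummy config are assumptions for illustration; the attribute checks and the IPEX_LLM_CPU_LM_HEAD override mirror the diff above. The values identify MiniCPM-V 2.6 (hidden_size 3584, vocab_size 151666) and the 40-layer minicpm-2b, which now place lm_head on CPU unconditionally, while other models still opt in via the environment variable:

# Sketch only: replicates the cpu_lm_head condition from the diff with a
# dummy config object; needs_cpu_lm_head is a hypothetical helper name.
import os
from types import SimpleNamespace

def needs_cpu_lm_head(config):
    return (
        (config.model_type == "minicpmv" and config.hidden_size == 3584 and
            config.vocab_size == 151666)                       # MiniCPM-V 2.6
        or (config.model_type == "minicpm" and
            config.num_hidden_layers == 40)                    # minicpm-2b
        or os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0"  # manual opt-in
    )

# MiniCPM-V 2.6 now forces lm_head onto the CPU regardless of the env var:
cfg = SimpleNamespace(model_type="minicpmv", hidden_size=3584,
                      vocab_size=151666, num_hidden_layers=28)
print(needs_cpu_lm_head(cfg))  # True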