Update lm_head optimization for Qwen2 7B (#12090)

Author:  Yuwen Hu (committed by GitHub)
Date:    2024-09-18 17:02:02 +08:00
Commit:  f7fb3c896c
Parent:  ee33b93464


@@ -89,10 +89,10 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
         from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj
         model.apply(split_mlp_down_proj)
-        # for Qwen2-7B-Insturct, divide lm_head into 7 parts
+        # for Qwen2-7B-Insturct, divide lm_head into 14 parts
         if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
                 not cpu_lm_head:
-            new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=7,
+            new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=14,
                                        bias=model.lm_head.bias)
             del model.lm_head
             model.lm_head = new_lm_head
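This hunk doubles the number of lm_head slices from 7 to 14 for Qwen2-7B-Instruct (hidden_size 3584, vocab_size 152064). One plausible reading of why the split count targets the hidden dimension: 3584 divides evenly by both 7 (512 columns per slice) and 14 (256 per slice), while 152064 does not. The sketch below illustrates that idea only; `SlicedLMHeadSketch` is a hypothetical stand-in, not the actual `SlicedLMHead` class from ipex_llm, and the slicing axis is an assumption.

```python
import torch
import torch.nn as nn


class SlicedLMHeadSketch(nn.Module):
    """Illustrative sketch (not ipex_llm's implementation): split an lm_head
    weight of shape (vocab_size, hidden_size) into `split_num` slices along
    the hidden dimension, so each matmul operand stays small. Summing the
    partial products recovers the full logits."""

    def __init__(self, weight: torch.Tensor, split_num: int, bias=None):
        super().__init__()
        vocab_size, hidden_size = weight.shape
        # Assumption: the hidden dimension is what gets sliced
        # (3584 % 14 == 0 -> 256 columns per slice).
        assert hidden_size % split_num == 0
        self.slice_size = hidden_size // split_num
        self.slices = nn.ModuleList()
        for i in range(split_num):
            lin = nn.Linear(self.slice_size, vocab_size, bias=False)
            lin.weight = nn.Parameter(
                weight[:, i * self.slice_size:(i + 1) * self.slice_size].contiguous())
            self.slices.append(lin)
        self.bias = bias

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # Split the activations to match the weight slices, then sum the
        # partial logits; the bias is added once at the end.
        chunks = hidden.split(self.slice_size, dim=-1)
        logits = sum(lin(c) for lin, c in zip(self.slices, chunks))
        if self.bias is not None:
            logits = logits + self.bias
        return logits
```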
@@ -192,7 +192,7 @@ def optimize_llm(
         from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
         convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
-        # for Qwen2-7B-Insturct, divide lm_head into 7 parts
+        # for Qwen2-7B-Insturct, divide lm_head into 14 parts
         if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
                 isinstance(model.lm_head, SlicedLMHead):
             model.lm_head.get_fused_lm_head()
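The second hunk only updates the matching comment: by this point the head has already been replaced with the sliced version, and `get_fused_lm_head()` is invoked on it. Whatever the split count, the sliced projection must stay numerically equivalent to the original dense one; the snippet below checks that property for the `SlicedLMHeadSketch` from the previous sketch, using toy shapes since Qwen2-7B's real weight (152064 x 3584) is large.

```python
import torch

# Toy shapes standing in for Qwen2-7B's hidden_size=3584 / vocab_size=152064;
# like 3584, 56 is divisible by 14 (3584 / 14 = 256, 56 / 14 = 4).
hidden_size, vocab_size = 56, 1024
weight = torch.randn(vocab_size, hidden_size) * 0.02
x = torch.randn(3, hidden_size)

dense_logits = x @ weight.T
sliced_logits = SlicedLMHeadSketch(weight, split_num=14)(x)

# The 14 partial matmuls summed together match the single dense matmul
# up to floating-point accumulation order.
print(torch.allclose(dense_logits, sliced_logits, atol=1e-5))  # True
```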