Update lm_head optimization for Qwen2 7B (#12090)
parent ee33b93464
commit f7fb3c896c

1 changed file with 3 additions and 3 deletions
@@ -89,10 +89,10 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
         from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj
         model.apply(split_mlp_down_proj)

-        # for Qwen2-7B-Instruct, divide lm_head into 7 parts
+        # for Qwen2-7B-Instruct, divide lm_head into 14 parts
         if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
                 not cpu_lm_head:
-            new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=7,
+            new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=14,
                                        bias=model.lm_head.bias)
             del model.lm_head
             model.lm_head = new_lm_head
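
The change itself is small: SlicedLMHead carves the lm_head projection into split_num smaller matmuls for the NPU, and this commit doubles the slice count for Qwen2-7B-Instruct from 7 to 14, so each slice covers 3584 / 14 = 256 input features instead of 512. Below is a minimal sketch of the slicing idea, assuming the weight is split along the hidden dimension with the partial products summed; the variable names and the reduced vocab size are illustrative, not ipex-llm's actual SlicedLMHead implementation.

import torch

# Illustrative shapes: hidden_size and split_num match the diff above;
# vocab_size is a reduced stand-in for the real 152064 to keep the sketch cheap.
hidden_size, vocab_size, split_num = 3584, 2048, 14

weight = torch.randn(vocab_size, hidden_size)   # lm_head weight, (vocab, hidden)
x = torch.randn(4, hidden_size)                 # a small batch of hidden states

full_logits = x @ weight.t()                    # logits = x @ W^T in one matmul

# Each of the 14 slices covers 3584 // 14 = 256 input features;
# summing the 14 partial matmuls recovers the full projection.
step = hidden_size // split_num                 # 256
sliced_logits = torch.zeros(4, vocab_size)
for i in range(split_num):
    w_i = weight[:, i * step:(i + 1) * step]    # (vocab, 256)
    x_i = x[:, i * step:(i + 1) * step]         # (4, 256)
    sliced_logits += x_i @ w_i.t()

assert torch.allclose(full_logits, sliced_logits, atol=1e-3)

The second hunk below makes the matching comment fix at the point where the sliced head is later prepared for fused NPU execution via get_fused_lm_head().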
@@ -192,7 +192,7 @@ def optimize_llm(
         from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
         convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)

-        # for Qwen2-7B-Instruct, divide lm_head into 7 parts
+        # for Qwen2-7B-Instruct, divide lm_head into 14 parts
         if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
                 isinstance(model.lm_head, SlicedLMHead):
             model.lm_head.get_fused_lm_head()
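
As context for the second hunk, convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward) rebinds the causal-LM forward to an NPU-aware variant. A self-contained sketch of that monkey-patching pattern follows; the helper body, Head, and patched_forward are illustrative assumptions, not ipex-llm's implementation.

import types
import torch

# Hypothetical convert_forward-style helper: rebind a replacement
# forward onto every module that is an instance of the target class.
def convert_forward(model, target_cls, new_forward):
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)

class Head(torch.nn.Module):
    def forward(self, x):
        return x + 1

def patched_forward(self, x):
    return x + 2

model = torch.nn.Sequential(Head())
convert_forward(model, Head, patched_forward)
assert torch.equal(model(torch.zeros(1)), torch.tensor([2.0]))

Binding with types.MethodType sets the replacement as an instance attribute, which shadows the class-level forward on that module only, leaving other instances of the class untouched.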