Update lm_head optimization for Qwen2 7B (#12090)
parent ee33b93464
commit f7fb3c896c
1 changed file with 3 additions and 3 deletions
@@ -89,10 +89,10 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
         from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj
         model.apply(split_mlp_down_proj)
 
-    # for Qwen2-7B-Insturct, divide lm_head into 7 parts
+    # for Qwen2-7B-Insturct, divide lm_head into 14 parts
     if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
             not cpu_lm_head:
-        new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=7,
+        new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=14,
                                    bias=model.lm_head.bias)
         del model.lm_head
         model.lm_head = new_lm_head
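
For context on what split_num controls: SlicedLMHead cuts the single large lm_head matmul into split_num smaller ones, presumably so each slice stays within sizes the NPU handles well. Since vocab_size (152064) is not divisible by 14 while hidden_size (3584) is (3584 / 14 = 256), the split has to run along the hidden dimension, with the partial logits summed. Below is a minimal sketch of that idea; SlicedLMHeadSketch is an illustration under that assumption, not ipex_llm's actual SlicedLMHead.

import torch
import torch.nn as nn

class SlicedLMHeadSketch(nn.Module):
    """Illustration of a hidden-dimension lm_head split (not ipex_llm's code)."""

    def __init__(self, weight: torch.Tensor, split_num: int, bias=None):
        super().__init__()
        vocab_size, hidden_size = weight.shape
        assert hidden_size % split_num == 0, "hidden dim must divide evenly"
        self.slices = nn.ModuleList()
        # Each slice is a smaller linear over a hidden_size // split_num window.
        for chunk in weight.split(hidden_size // split_num, dim=1):
            head = nn.Linear(chunk.shape[1], vocab_size, bias=False)
            head.weight = nn.Parameter(chunk.contiguous())
            self.slices.append(head)
        self.bias = bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Split the activations the same way and sum the partial logits.
        step = hidden_states.shape[-1] // len(self.slices)
        parts = hidden_states.split(step, dim=-1)
        logits = sum(head(part) for head, part in zip(self.slices, parts))
        return logits if self.bias is None else logits + self.bias

# Toy dimensions; for Qwen2-7B this would be (152064, 3584) with split_num=14.
head = SlicedLMHeadSketch(torch.randn(1024, 64), split_num=8)
logits = head(torch.randn(2, 5, 64))   # -> shape (2, 5, 1024)

With split_num=14 each slice works on a 256-wide window of the 3584-dim hidden state; the commit doubles the slice count from 7 (512 per slice) to 14.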
@@ -192,7 +192,7 @@ def optimize_llm(
         from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
         convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
 
-    # for Qwen2-7B-Insturct, divide lm_head into 7 parts
+    # for Qwen2-7B-Insturct, divide lm_head into 14 parts
     if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
             isinstance(model.lm_head, SlicedLMHead):
         model.lm_head.get_fused_lm_head()
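
The two hunks form a pair: optimize_llm_pre installs the sliced head, and optimize_llm later detects it via isinstance and calls get_fused_lm_head() to build the runtime operator. What that fusion produces on the NPU is not shown in the diff; conceptually, the slices can be stitched back into one linear layer by concatenating their weights along the hidden dimension. A sketch of that inverse step, reusing the hypothetical SlicedLMHeadSketch above (fuse_slices is likewise an illustration, not the real get_fused_lm_head):

import torch

def fuse_slices(sliced_head) -> torch.nn.Linear:
    # Concatenate the per-slice [vocab_size, hidden_size // split_num]
    # weights back into one [vocab_size, hidden_size] matrix.
    weight = torch.cat([head.weight for head in sliced_head.slices], dim=1)
    fused = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
    fused.weight = torch.nn.Parameter(weight)
    return fused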