Support lm_head of minicpm-2b on NPU (#12019)

This commit is contained in:
binbin Deng 2024-09-05 16:19:22 +08:00 committed by GitHub
parent 820f8a4554
commit 845e5dc89e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 105 additions and 4 deletions

View file

@ -42,16 +42,34 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
model.apply(pre_compute_inv_freq)
# MiniCPM-V 2.6 and minicpm-2b must put lm_head on CPU now
# MiniCPM-V 2.6 must put lm_head on CPU now
cpu_lm_head = (
(model.config.model_type == "minicpmv" and model.config.hidden_size == 3584 and
model.config.vocab_size == 151666)
or (
model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40
)
or os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0"
)
# workaround for MiniCPM-2B
if model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40:
# 73440 is vocab_size of MiniCPM-1B
new_linear_0 = torch.nn.Linear(0, 0, bias=False)
new_weight_0 = torch.nn.Parameter(model.lm_head.weight[:73440, :], requires_grad=False)
new_linear_0.weight = new_weight_0
new_linear_0.in_features = new_weight_0.size(1)
new_linear_0.out_features = new_weight_0.size(0)
model.lm_head_0 = new_linear_0
new_linear_1 = torch.nn.Linear(0, 0, bias=False)
import torch.nn.functional as F
padded_weight = F.pad(model.lm_head.weight[73440:, :],
(0, 0, 0, 73440*2 - model.config.vocab_size))
new_weight_1 = torch.nn.Parameter(padded_weight, requires_grad=False)
new_linear_1.weight = new_weight_1
new_linear_1.in_features = new_weight_1.size(1)
new_linear_1.out_features = new_weight_1.size(0)
model.lm_head_1 = new_linear_1
del model.lm_head
if model.config.model_type == "minicpmv" and hasattr(model, "llm"):
# MiniCPM-V
if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
@ -201,6 +219,10 @@ def optimize_llm(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
if model.config.num_hidden_layers == 40:
# for minicpm-2b
from ipex_llm.transformers.npu_models.minicpm_mp import minicpm_casullm_forward
convert_forward(model, module.MiniCPMForCausalLM, minicpm_casullm_forward)
elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
# for Baichuan2-7B
if intra_pp is None:

View file

@ -50,6 +50,8 @@ from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from ipex_llm.transformers.npu_models.mp_models_base import run_model
from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
@ -985,3 +987,80 @@ def gen_minicpm_fused_model_forward(prefill_runner, decode_runner):
)
return minicpm_fused_model_forward
def minicpm_casullm_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None \
else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp,
dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i])
for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
# ipex-llm change start
logits1 = self.lm_head_0(hidden_states / (self.config.hidden_size /
self.config.dim_model_base))
logits2 = self.lm_head_1(hidden_states / (self.config.hidden_size /
self.config.dim_model_base))
logits = torch.cat((logits1, logits2[:, :, :49313]), dim=-1)
# ipex-llm change end
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)