Support lm_head of minicpm-2b on NPU (#12019)
commit 845e5dc89e (parent 820f8a4554)
2 changed files with 105 additions and 4 deletions
@@ -42,16 +42,34 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
         from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
 
-    # MiniCPM-V 2.6 must put lm_head on CPU now
+    # MiniCPM-V 2.6 and minicpm-2b must put lm_head on CPU now
     cpu_lm_head = (
         (model.config.model_type == "minicpmv" and model.config.hidden_size == 3584 and
          model.config.vocab_size == 151666)
+        or (
+            model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40
+        )
         or os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0"
     )
 
+    # workaround for MiniCPM-2B
+    if model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40:
+        # 73440 is vocab_size of MiniCPM-1B
+        new_linear_0 = torch.nn.Linear(0, 0, bias=False)
+        new_weight_0 = torch.nn.Parameter(model.lm_head.weight[:73440, :], requires_grad=False)
+        new_linear_0.weight = new_weight_0
+        new_linear_0.in_features = new_weight_0.size(1)
+        new_linear_0.out_features = new_weight_0.size(0)
+        model.lm_head_0 = new_linear_0
+
+        new_linear_1 = torch.nn.Linear(0, 0, bias=False)
+        import torch.nn.functional as F
+        padded_weight = F.pad(model.lm_head.weight[73440:, :],
+                              (0, 0, 0, 73440*2 - model.config.vocab_size))
+        new_weight_1 = torch.nn.Parameter(padded_weight, requires_grad=False)
+        new_linear_1.weight = new_weight_1
+        new_linear_1.in_features = new_weight_1.size(1)
+        new_linear_1.out_features = new_weight_1.size(0)
+        model.lm_head_1 = new_linear_1
+        del model.lm_head
+
     if model.config.model_type == "minicpmv" and hasattr(model, "llm"):
         # MiniCPM-V
         if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
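Not part of the commit, but a minimal sketch of what the split above achieves: the 122753-row lm_head weight (MiniCPM-2B's vocab size) is cut into a first head with 73440 rows (MiniCPM-1B's vocab size, used as the split point) and a second head with the remaining 49313 rows zero-padded back up to 73440 so both heads share one shape; concatenating their outputs and dropping the padded columns reproduces the original projection. The variable names and the small hidden size below are arbitrary choices for illustration.

import torch
import torch.nn.functional as F

vocab_size, split, hidden = 122753, 73440, 64   # hidden size is arbitrary for this sketch
lm_head = torch.nn.Linear(hidden, vocab_size, bias=False)
x = torch.randn(1, 8, hidden)

# Head 0: first 73440 rows; head 1: remaining 49313 rows, zero-padded to 73440 rows
w0 = lm_head.weight[:split, :]
w1 = F.pad(lm_head.weight[split:, :], (0, 0, 0, split * 2 - vocab_size))

logits = torch.cat((F.linear(x, w0),
                    F.linear(x, w1)[:, :, :vocab_size - split]), dim=-1)
assert torch.allclose(logits, lm_head(x), atol=1e-6)   # matches the unsplit projection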
@@ -201,6 +219,10 @@ def optimize_llm(
             prefill_runner=prefill_runner, decode_runner=decode_runner
         )
         convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
+        if model.config.num_hidden_layers == 40:
+            # for minicpm-2b
+            from ipex_llm.transformers.npu_models.minicpm_mp import minicpm_casullm_forward
+            convert_forward(model, module.MiniCPMForCausalLM, minicpm_casullm_forward)
     elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
         # for Baichuan2-7B
         if intra_pp is None:
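The convert_forward calls above follow ipex-llm's usual monkey-patching pattern: every module of the given class gets its forward method rebound to the replacement function. The helper below is only an illustrative sketch of that pattern (the name convert_forward_sketch is made up here); it is not the library's actual implementation.

import types
import torch

def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
    # Walk the module tree (the model itself is included by Module.modules())
    # and rebind `forward` on every instance of `target_cls`.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)

# e.g. convert_forward_sketch(model, module.MiniCPMForCausalLM, minicpm_casullm_forward)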
@@ -50,6 +50,8 @@ from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.transformers.npu_models.mp_models_base import run_model
 from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from torch.nn import CrossEntropyLoss
 
 
 class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
@@ -985,3 +987,80 @@ def gen_minicpm_fused_model_forward(prefill_runner, decode_runner):
         )
 
     return minicpm_fused_model_forward
+
+
+def minicpm_casullm_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+    if self.config.pretraining_tp > 1:
+        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp,
+                                                   dim=0)
+        logits = [F.linear(hidden_states, lm_head_slices[i])
+                  for i in range(self.config.pretraining_tp)]
+        logits = torch.cat(logits, dim=-1)
+    else:
+        # ipex-llm change start
+        logits1 = self.lm_head_0(hidden_states / (self.config.hidden_size /
+                                                  self.config.dim_model_base))
+        logits2 = self.lm_head_1(hidden_states / (self.config.hidden_size /
+                                                  self.config.dim_model_base))
+        logits = torch.cat((logits1, logits2[:, :, :49313]), dim=-1)
+        # ipex-llm change end
+    logits = logits.float()
+
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
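Two notes on the new forward, plus a small illustration. In the else branch, both heads consume the same scaled hidden states (divided by hidden_size / dim_model_base, MiniCPM's usual output scaling), and the concatenation restores the full vocabulary: 73440 columns from lm_head_0 plus the first 49313 (= 122753 - 73440) columns of lm_head_1. The loss branch is the standard causal-LM label shift, where position i's logits are scored against token i+1; below is a self-contained toy example of that shift, not from the commit, with arbitrary sizes and values.

import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(1, 4, 10)                   # (batch, seq_len, vocab)
labels = torch.tensor([[3, 5, 7, 1]])            # one label per position

shift_logits = logits[..., :-1, :].contiguous()  # predictions at positions 0..2
shift_labels = labels[..., 1:].contiguous()      # targets are tokens 1..3
loss = CrossEntropyLoss()(shift_logits.view(-1, 10), shift_labels.view(-1))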