diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
index 0c70bf63..5386c426 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -42,16 +42,34 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
         from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
 
-    # MiniCPM-V 2.6 and minicpm-2b must put lm_head on CPU now
+    # MiniCPM-V 2.6 must put lm_head on CPU now
     cpu_lm_head = (
         (model.config.model_type == "minicpmv" and model.config.hidden_size == 3584 and
          model.config.vocab_size == 151666)
-        or (
-            model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40
-        )
         or os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0"
     )
 
+    # workaround for MiniCPM-2B
+    if model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40:
+        # 73440 is vocab_size of MiniCPM-1B
+        new_linear_0 = torch.nn.Linear(0, 0, bias=False)
+        new_weight_0 = torch.nn.Parameter(model.lm_head.weight[:73440, :], requires_grad=False)
+        new_linear_0.weight = new_weight_0
+        new_linear_0.in_features = new_weight_0.size(1)
+        new_linear_0.out_features = new_weight_0.size(0)
+        model.lm_head_0 = new_linear_0
+
+        new_linear_1 = torch.nn.Linear(0, 0, bias=False)
+        import torch.nn.functional as F
+        padded_weight = F.pad(model.lm_head.weight[73440:, :],
+                              (0, 0, 0, 73440*2 - model.config.vocab_size))
+        new_weight_1 = torch.nn.Parameter(padded_weight, requires_grad=False)
+        new_linear_1.weight = new_weight_1
+        new_linear_1.in_features = new_weight_1.size(1)
+        new_linear_1.out_features = new_weight_1.size(0)
+        model.lm_head_1 = new_linear_1
+        del model.lm_head
+
     if model.config.model_type == "minicpmv" and hasattr(model, "llm"):
         # MiniCPM-V
         if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
@@ -201,6 +219,10 @@ def optimize_llm(
             prefill_runner=prefill_runner, decode_runner=decode_runner
         )
         convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
+        if model.config.num_hidden_layers == 40:
+            # for minicpm-2b
+            from ipex_llm.transformers.npu_models.minicpm_mp import minicpm_casullm_forward
+            convert_forward(model, module.MiniCPMForCausalLM, minicpm_casullm_forward)
     elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
         # for Baichuan2-7B
         if intra_pp is None:
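Note on the lm_head split above: a minimal, self-contained sketch (not part of the patch) of the same trick, using toy dimensions in place of MiniCPM-2B's vocab_size=122753, hidden_size=2304, and the 73440-row cut (MiniCPM-1B's vocab_size); the names head_0, head_1, and split are invented for illustration. It checks that concatenating the two partial projections and dropping the zero-padded tail reproduces the original lm_head output, which is presumably what lets the 2B head reuse a path sized for the 1B vocabulary.

import torch
import torch.nn.functional as F

hidden_size, vocab_size, split = 8, 11, 6        # toy stand-ins for 2304 / 122753 / 73440
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

# Piece 0: the first `split` rows of the original weight.
head_0 = torch.nn.Linear(hidden_size, split, bias=False)
head_0.weight = torch.nn.Parameter(lm_head.weight[:split, :], requires_grad=False)

# Piece 1: the remaining rows, zero-padded so both pieces share the same out_features.
padded = F.pad(lm_head.weight[split:, :], (0, 0, 0, split * 2 - vocab_size))
head_1 = torch.nn.Linear(hidden_size, split, bias=False)
head_1.weight = torch.nn.Parameter(padded, requires_grad=False)

x = torch.randn(1, 4, hidden_size)
reference = lm_head(x)
recombined = torch.cat((head_0(x), head_1(x)[:, :, :vocab_size - split]), dim=-1)
assert torch.allclose(reference, recombined)     # identical logits once the padding is sliced off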
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
index 80abb8ba..48271271 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
@@ -50,6 +50,8 @@ from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.transformers.npu_models.mp_models_base import run_model
 from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from torch.nn import CrossEntropyLoss
 
 
 class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
@@ -985,3 +987,80 @@ def gen_minicpm_fused_model_forward(prefill_runner, decode_runner):
         )
 
     return minicpm_fused_model_forward
+
+
+def minicpm_casullm_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+    if self.config.pretraining_tp > 1:
+        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp,
+                                                   dim=0)
+        logits = [F.linear(hidden_states, lm_head_slices[i])
+                  for i in range(self.config.pretraining_tp)]
+        logits = torch.cat(logits, dim=-1)
+    else:
+        # ipex-llm change start
+        logits1 = self.lm_head_0(hidden_states / (self.config.hidden_size /
+                                                  self.config.dim_model_base))
+        logits2 = self.lm_head_1(hidden_states / (self.config.hidden_size /
+                                                  self.config.dim_model_base))
+        logits = torch.cat((logits1, logits2[:, :, :49313]), dim=-1)
+        # ipex-llm change end
+    logits = logits.float()
+
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
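For reference, the hard-coded constants in minicpm_casullm_forward follow directly from the split done in convert_mp.py. A small arithmetic sketch, assuming MiniCPM-2B's released dim_model_base of 256 (vocab_size=122753 and hidden_size=2304 appear in the diff's own context lines; the variable names below are invented):

vocab_size, hidden_size, dim_model_base = 122753, 2304, 256
split = 73440                                # rows kept in lm_head_0 (MiniCPM-1B's vocab_size)

kept_from_head_1 = vocab_size - split        # 49313, the slice bound applied to logits2
zero_padding = split * 2 - vocab_size        # 24127 rows of zeros appended by F.pad
scale = hidden_size / dim_model_base         # 9.0, the divisor applied to hidden_states

assert split + kept_from_head_1 == vocab_size
print(kept_from_head_1, zero_padding, scale)

If dim_model_base differs in the checkout being patched, only the printed scale changes; the two vocabulary constants are fixed by the split itself.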