diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py index baa38d97..23a394e1 100644 --- a/python/llm/src/ipex_llm/transformers/models/minicpm.py +++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py @@ -14,10 +14,15 @@ # limitations under the License. # # Some parts of this file is adapted from -# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py +# https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16/blob/main/modeling_minicpm.py # which is licensed under Apache License 2.0: # -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index 40efff47..8103df5c 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -15,6 +15,7 @@ import torch +import importlib from intel_npu_acceleration_library.nn import QuantizedLinear @@ -95,8 +96,26 @@ def optimize_llm(model: torch.nn.Module): from ipex_llm.transformers.npu_models.qwen2 import qwen2_attention_forward from ipex_llm.transformers.npu_models.qwen2 import qwen2_mlp_forward from transformers.models.qwen2.modeling_qwen2 import Qwen2Model - from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention + from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2SdpaAttention from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP convert_forward(model, Qwen2Model, qwen2_model_forward) convert_forward(model, Qwen2Attention, qwen2_attention_forward) + convert_forward(model, Qwen2SdpaAttention, qwen2_attention_forward) convert_forward(model, Qwen2MLP, qwen2_mlp_forward) + + elif model.config.model_type == "minicpm": + from ipex_llm.transformers.npu_models.minicpm import merge_qkv + from ipex_llm.transformers.npu_models.minicpm import merge_mlp + from ipex_llm.transformers.npu_models.minicpm import padding_lm_head + model.apply(merge_qkv) + model.apply(merge_mlp) + model.apply(padding_lm_head) + + from ipex_llm.transformers.npu_models.minicpm import minicpm_model_causal_lm_forward + from ipex_llm.transformers.npu_models.minicpm import minicpm_attention_forward + from ipex_llm.transformers.npu_models.minicpm import minicpm_mlp_forward + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + convert_forward(model, module.MiniCPMForCausalLM, minicpm_model_causal_lm_forward) + convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward) + convert_forward(model, module.MiniCPMMLP, minicpm_mlp_forward) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/minicpm.py b/python/llm/src/ipex_llm/transformers/npu_models/minicpm.py new file mode 100644 index 00000000..d3ff6bcf --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_models/minicpm.py @@ -0,0 +1,256 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16/blob/main/modeling_minicpm.py +# which is licensed under Apache License 2.0: +# +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple, Union, List + +import torch +from torch.nn import CrossEntropyLoss + +from ipex_llm.transformers.npu_models.common import merge_linear +from ipex_llm.transformers.kv import DynamicNormalCache + +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast + + +def merge_qkv(module: torch.nn.Module): + if module.__class__.__name__ in ["MiniCPMAttention", "MiniCPMSdpaAttention"]: + qkv_proj = merge_linear([ + module.q_proj, + module.k_proj, + module.v_proj + ]) + module.qkv_proj = qkv_proj + del module.q_proj, module.k_proj, module.v_proj + + +def merge_mlp(module: torch.nn.Module): + if module.__class__.__name__ == "MiniCPMMLP": + gate_up_proj = merge_linear([ + module.gate_proj, + module.up_proj, + ]) + module.gate_up_proj = gate_up_proj + del module.gate_proj, module.up_proj + + +def padding_lm_head(module: torch.nn.Module): + if isinstance(module, torch.nn.Linear) and module.out_features == 122753: + new_weight = torch.empty(122816, module.in_features, + dtype=module.weight.dtype, device=module.weight.device) + new_weight[:122753, ...] = module.weight.data + module.out_features = 122816 + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + + +def minicpm_model_causal_lm_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = ( + output_attentions if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # ipex-llm changes start: kv cache + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache and not isinstance(past_key_values, DynamicNormalCache): + past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) + # ipex-llm changes end + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states / + (self.config.hidden_size / self.config.dim_model_base)) + + # ipex-llm changes start: truncate logits to fix vocab size and remove logits.float() + logits = logits[..., :self.config.vocab_size] + # ipex-llm changes end + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, + n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + # cos = cos[position_ids].unsqueeze(unsqueeze_dim) + # sin = sin[position_ids].unsqueeze(unsqueeze_dim) + # q_embed = (q * cos) + (rotate_half(q) * sin) + # k_embed = (k * cos) + (rotate_half(k) * sin) + orig_dtype = k.dtype + cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] + q_fp32 = q.to(dtype=torch.float32, device=q.device) + k_fp32 = k.to(dtype=torch.float32, device=k.device) + q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin) + k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin) + return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype) + + +def minicpm_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) + qkv = qkv.transpose(1, 2) + query_states, key_states, value_states = qkv.split([self.num_heads, + self.num_key_value_heads, + self.num_key_value_heads], dim=1) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, None) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if query_states.size(2) == key_states.size(2): + # first token + from intel_npu_acceleration_library.functional import scaled_dot_product_attention + attn_output = scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=q_len > 1 and bsz == 1, + ) + attn_weights = None + else: + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, + dtype=torch.float32).to(query_states.dtype) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, + training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + +def minicpm_mlp_forward(self, x): + gate_up_proj = self.gate_up_proj(x) + gate_proj, up_proj = gate_up_proj.chunk(2, dim=-1) + down_proj = self.down_proj(self.act_fn(gate_proj) * up_proj) + return down_proj