add minicpm 1B/2B npu support (#11507)

Yishuo Wang 2024-07-04 16:31:04 +08:00 committed by GitHub
parent bb0a84044b
commit 1a8bab172e
3 changed files with 283 additions and 3 deletions


@@ -14,10 +14,15 @@
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/llama/modeling_llama.py
# https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16/blob/main/modeling_minicpm.py
# which is licensed under Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.


@@ -15,6 +15,7 @@
import torch
import importlib
from intel_npu_acceleration_library.nn import QuantizedLinear
@@ -95,8 +96,26 @@ def optimize_llm(model: torch.nn.Module):
from ipex_llm.transformers.npu_models.qwen2 import qwen2_attention_forward
from ipex_llm.transformers.npu_models.qwen2 import qwen2_mlp_forward
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2SdpaAttention
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
convert_forward(model, Qwen2Model, qwen2_model_forward)
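# Register the NPU attention forward for both the eager (Qwen2Attention) and
# SDPA (Qwen2SdpaAttention) classes, so the patch applies regardless of the
# model's configured attn_implementation.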
convert_forward(model, Qwen2Attention, qwen2_attention_forward)
convert_forward(model, Qwen2SdpaAttention, qwen2_attention_forward)
convert_forward(model, Qwen2MLP, qwen2_mlp_forward)
elif model.config.model_type == "minicpm":
from ipex_llm.transformers.npu_models.minicpm import merge_qkv
from ipex_llm.transformers.npu_models.minicpm import merge_mlp
from ipex_llm.transformers.npu_models.minicpm import padding_lm_head
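# torch.nn.Module.apply walks every submodule, so these three passes fuse the
# attention and MLP projections and pad the LM head in place before the
# forward functions are swapped below.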
model.apply(merge_qkv)
model.apply(merge_mlp)
model.apply(padding_lm_head)
from ipex_llm.transformers.npu_models.minicpm import minicpm_model_causal_lm_forward
from ipex_llm.transformers.npu_models.minicpm import minicpm_attention_forward
from ipex_llm.transformers.npu_models.minicpm import minicpm_mlp_forward
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
convert_forward(model, module.MiniCPMForCausalLM, minicpm_model_causal_lm_forward)
convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
convert_forward(model, module.MiniCPMMLP, minicpm_mlp_forward)
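
For context, a minimal usage sketch (illustrative only: the model id is an example, and in practice optimize_llm is invoked by ipex-llm's NPU loading path rather than called directly):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True)
optimize_llm(model)  # fuses QKV/MLP, pads the LM head, swaps in NPU forwards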


@@ -0,0 +1,256 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16/blob/main/modeling_minicpm.py
# which is licensed under Apache License 2.0:
#
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple, Union, List
import torch
from torch.nn import CrossEntropyLoss
from ipex_llm.transformers.npu_models.common import merge_linear
from ipex_llm.transformers.kv import DynamicNormalCache
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
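
# Fuse the separate q/k/v projection layers into a single linear so the NPU
# executes one matmul per attention block instead of three.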
def merge_qkv(module: torch.nn.Module):
if module.__class__.__name__ in ["MiniCPMAttention", "MiniCPMSdpaAttention"]:
qkv_proj = merge_linear([
module.q_proj,
module.k_proj,
module.v_proj
])
module.qkv_proj = qkv_proj
del module.q_proj, module.k_proj, module.v_proj
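
# Likewise fuse gate_proj and up_proj into one gate_up projection for the MLP.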
def merge_mlp(module: torch.nn.Module):
if module.__class__.__name__ == "MiniCPMMLP":
gate_up_proj = merge_linear([
module.gate_proj,
module.up_proj,
])
module.gate_up_proj = gate_up_proj
del module.gate_proj, module.up_proj
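
# MiniCPM's vocabulary size is 122753; pad the LM head to 122816 (the next
# multiple of 64), presumably for NPU alignment. The padded rows are left
# uninitialized, which is safe because the causal-LM forward below truncates
# logits back to config.vocab_size.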
def padding_lm_head(module: torch.nn.Module):
if isinstance(module, torch.nn.Linear) and module.out_features == 122753:
new_weight = torch.empty(122816, module.in_features,
dtype=module.weight.dtype, device=module.weight.device)
new_weight[:122753, ...] = module.weight.data
module.out_features = 122816
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
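
# Replacement for MiniCPMForCausalLM.forward. It mirrors the original
# modeling_minicpm.py implementation except for the marked ipex-llm changes:
# wrapping the KV cache in DynamicNormalCache, truncating the padded logits,
# and dropping the logits.float() upcast.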
def minicpm_model_causal_lm_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = (
output_attentions if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# ipex-llm changes start: kv cache
use_cache = use_cache if use_cache is not None else self.config.use_cache
if use_cache and not isinstance(past_key_values, DynamicNormalCache):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
# ipex-llm changes end
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
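# MiniCPM scales hidden states by dim_model_base / hidden_size before the
# LM head (its muP-style output scaling); this matches the original
# modeling_minicpm.py.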
logits = self.lm_head(hidden_states /
(self.config.hidden_size / self.config.dim_model_base))
# ipex-llm changes start: truncate logits to fix vocab size and remove logits.float()
logits = logits[..., :self.config.vocab_size]
# ipex-llm changes end
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
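
# Expand KV heads for grouped-query attention; equivalent to transformers'
# repeat_kv helper.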
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
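
# Rotary embedding applied in float32 and cast back to the original dtype for
# accuracy, as in the original modeling_minicpm.py; the commented lines at the
# top of the function preserve the standard half-precision formulation for
# reference.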
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
# cos = cos[position_ids].unsqueeze(unsqueeze_dim)
# sin = sin[position_ids].unsqueeze(unsqueeze_dim)
# q_embed = (q * cos) + (rotate_half(q) * sin)
# k_embed = (k * cos) + (rotate_half(k) * sin)
orig_dtype = k.dtype
cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
q_fp32 = q.to(dtype=torch.float32, device=q.device)
k_fp32 = k.to(dtype=torch.float32, device=k.device)
q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
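
# Replacement for MiniCPMAttention.forward (and its SDPA variant). Differences
# from the original: the fused qkv_proj from merge_qkv, DynamicNormalCache for
# the KV cache, and the NPU library's fused scaled_dot_product_attention for
# the prefill pass.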
def minicpm_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
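# One fused QKV matmul, then split along the head dimension into num_heads
# query heads and num_key_value_heads key/value heads.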
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
qkv = qkv.transpose(1, 2)
query_states, key_states, value_states = qkv.split([self.num_heads,
self.num_key_value_heads,
self.num_key_value_heads], dim=1)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)
if past_key_value is not None:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
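# Prefill vs. decode dispatch: with an empty cache the query and key lengths
# match, so the whole prompt goes through the NPU library's fused SDPA;
# single-token decode falls through to eager attention below.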
if query_states.size(2) == key_states.size(2):
# first token
from intel_npu_acceleration_library.functional import scaled_dot_product_attention
attn_output = scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=q_len > 1 and bsz == 1,
)
attn_weights = None
else:
attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
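
# Replacement for MiniCPMMLP.forward using the fused gate_up_proj from
# merge_mlp: one matmul, chunked into its gate and up halves.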
def minicpm_mlp_forward(self, x):
gate_up_proj = self.gate_up_proj(x)
gate_proj, up_proj = gate_up_proj.chunk(2, dim=-1)
down_proj = self.down_proj(self.act_fn(gate_proj) * up_proj)
return down_proj