optimize minicpm (#12496)

This commit is contained in:
Yishuo Wang 2024-12-04 17:14:16 +08:00 committed by GitHub
parent ae9c2154f4
commit a9e3f7f14c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 64 additions and 1 deletions

View file

@ -1032,8 +1032,9 @@ def _optimize_pre(model, qtype=None):
from ipex_llm.transformers.models.mllama import merge_qkv from ipex_llm.transformers.models.mllama import merge_qkv
model.apply(merge_qkv) model.apply(merge_qkv)
elif model.config.model_type == "minicpm": elif model.config.model_type == "minicpm":
from ipex_llm.transformers.models.minicpm import merge_qkv from ipex_llm.transformers.models.minicpm import merge_qkv, apply_residual_scale
model.apply(merge_qkv) model.apply(merge_qkv)
model.apply(apply_residual_scale)
elif model.config.model_type == "minicpm3": elif model.config.model_type == "minicpm3":
from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
model.apply(pre_compute_inv_freq) model.apply(pre_compute_inv_freq)
@ -2101,9 +2102,11 @@ def _optimize_post(model, lightweight_bmm=False):
module = importlib.import_module(modeling_module_name) module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.minicpm import minicpm_attention_forward from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
from ipex_llm.transformers.models.minicpm import minicpm_decoder_layer_forward
convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward) convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
convert_forward(model, module.MiniCPMMLP, llama_mlp_forward) convert_forward(model, module.MiniCPMMLP, llama_mlp_forward)
convert_forward(model, module.MiniCPMRMSNorm, llama_rms_norm_forward) convert_forward(model, module.MiniCPMRMSNorm, llama_rms_norm_forward)
convert_forward(model, module.MiniCPMDecoderLayer, minicpm_decoder_layer_forward)
minicpm_model_forward = minicpm_model_forward_wrapper(module.MiniCPMModel.forward) minicpm_model_forward = minicpm_model_forward_wrapper(module.MiniCPMModel.forward)
convert_forward(model, module.MiniCPMModel, minicpm_model_forward) convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
elif model.config.model_type == "minicpm3": elif model.config.model_type == "minicpm3":

View file

@ -56,6 +56,17 @@ def merge_qkv(module: torch.nn.Module):
return merge_qkv_base(module, "MiniCPMAttention") return merge_qkv_base(module, "MiniCPMAttention")
def apply_residual_scale(module: torch.nn.Module):
if module.__class__.__name__ == "MiniCPMDecoderLayer":
scale = module.scale_depth / math.sqrt(module.num_hidden_layers)
module.self_attn.o_proj.weight.data *= scale
if module.self_attn.o_proj.bias is not None:
module.self_attn.o_proj.bias.weight.data *= scale
module.mlp.down_proj.weight.data *= scale
if module.mlp.down_proj.bias is not None:
module.mlp.down_proj.bias.weight.data *= scale
def minicpm_attention_forward( def minicpm_attention_forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -214,3 +225,52 @@ def minicpm_model_forward_wrapper(origin_forward):
) )
return minicpm_model_forward return minicpm_model_forward
def minicpm_decoder_layer_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs,
)
# ipex-llm changes start
hidden_states = residual + hidden_states
# ipex-llm changes end
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
# ipex-llm changes start
hidden_states = residual + hidden_states
# ipex-llm changes end
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs