optimize minicpm (#12496)
This commit is contained in:
parent
ae9c2154f4
commit
a9e3f7f14c
2 changed files with 64 additions and 1 deletions
|
|
@ -1032,8 +1032,9 @@ def _optimize_pre(model, qtype=None):
|
||||||
from ipex_llm.transformers.models.mllama import merge_qkv
|
from ipex_llm.transformers.models.mllama import merge_qkv
|
||||||
model.apply(merge_qkv)
|
model.apply(merge_qkv)
|
||||||
elif model.config.model_type == "minicpm":
|
elif model.config.model_type == "minicpm":
|
||||||
from ipex_llm.transformers.models.minicpm import merge_qkv
|
from ipex_llm.transformers.models.minicpm import merge_qkv, apply_residual_scale
|
||||||
model.apply(merge_qkv)
|
model.apply(merge_qkv)
|
||||||
|
model.apply(apply_residual_scale)
|
||||||
elif model.config.model_type == "minicpm3":
|
elif model.config.model_type == "minicpm3":
|
||||||
from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
|
from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
|
||||||
model.apply(pre_compute_inv_freq)
|
model.apply(pre_compute_inv_freq)
|
||||||
|
|
@ -2101,9 +2102,11 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
|
from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
|
||||||
from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
|
from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
|
||||||
|
from ipex_llm.transformers.models.minicpm import minicpm_decoder_layer_forward
|
||||||
convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
|
convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
|
||||||
convert_forward(model, module.MiniCPMMLP, llama_mlp_forward)
|
convert_forward(model, module.MiniCPMMLP, llama_mlp_forward)
|
||||||
convert_forward(model, module.MiniCPMRMSNorm, llama_rms_norm_forward)
|
convert_forward(model, module.MiniCPMRMSNorm, llama_rms_norm_forward)
|
||||||
|
convert_forward(model, module.MiniCPMDecoderLayer, minicpm_decoder_layer_forward)
|
||||||
minicpm_model_forward = minicpm_model_forward_wrapper(module.MiniCPMModel.forward)
|
minicpm_model_forward = minicpm_model_forward_wrapper(module.MiniCPMModel.forward)
|
||||||
convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
|
convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
|
||||||
elif model.config.model_type == "minicpm3":
|
elif model.config.model_type == "minicpm3":
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,17 @@ def merge_qkv(module: torch.nn.Module):
|
||||||
return merge_qkv_base(module, "MiniCPMAttention")
|
return merge_qkv_base(module, "MiniCPMAttention")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_residual_scale(module: torch.nn.Module):
|
||||||
|
if module.__class__.__name__ == "MiniCPMDecoderLayer":
|
||||||
|
scale = module.scale_depth / math.sqrt(module.num_hidden_layers)
|
||||||
|
module.self_attn.o_proj.weight.data *= scale
|
||||||
|
if module.self_attn.o_proj.bias is not None:
|
||||||
|
module.self_attn.o_proj.bias.weight.data *= scale
|
||||||
|
module.mlp.down_proj.weight.data *= scale
|
||||||
|
if module.mlp.down_proj.bias is not None:
|
||||||
|
module.mlp.down_proj.bias.weight.data *= scale
|
||||||
|
|
||||||
|
|
||||||
def minicpm_attention_forward(
|
def minicpm_attention_forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|
@ -214,3 +225,52 @@ def minicpm_model_forward_wrapper(origin_forward):
|
||||||
)
|
)
|
||||||
|
|
||||||
return minicpm_model_forward
|
return minicpm_model_forward
|
||||||
|
|
||||||
|
|
||||||
|
def minicpm_decoder_layer_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ipex-llm changes start
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
# ipex-llm changes end
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
|
# ipex-llm changes start
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
# ipex-llm changes end
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (self_attn_weights,)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
outputs += (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue