optimize minicpm (#12496)
This commit is contained in:
parent
ae9c2154f4
commit
a9e3f7f14c
2 changed files with 64 additions and 1 deletions
|
|
@ -1032,8 +1032,9 @@ def _optimize_pre(model, qtype=None):
|
|||
from ipex_llm.transformers.models.mllama import merge_qkv
|
||||
model.apply(merge_qkv)
|
||||
elif model.config.model_type == "minicpm":
|
||||
from ipex_llm.transformers.models.minicpm import merge_qkv
|
||||
from ipex_llm.transformers.models.minicpm import merge_qkv, apply_residual_scale
|
||||
model.apply(merge_qkv)
|
||||
model.apply(apply_residual_scale)
|
||||
elif model.config.model_type == "minicpm3":
|
||||
from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
|
||||
model.apply(pre_compute_inv_freq)
|
||||
|
|
@ -2101,9 +2102,11 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
|
||||
from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
|
||||
from ipex_llm.transformers.models.minicpm import minicpm_decoder_layer_forward
|
||||
convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
|
||||
convert_forward(model, module.MiniCPMMLP, llama_mlp_forward)
|
||||
convert_forward(model, module.MiniCPMRMSNorm, llama_rms_norm_forward)
|
||||
convert_forward(model, module.MiniCPMDecoderLayer, minicpm_decoder_layer_forward)
|
||||
minicpm_model_forward = minicpm_model_forward_wrapper(module.MiniCPMModel.forward)
|
||||
convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
|
||||
elif model.config.model_type == "minicpm3":
|
||||
|
|
|
|||
|
|
@ -56,6 +56,17 @@ def merge_qkv(module: torch.nn.Module):
|
|||
return merge_qkv_base(module, "MiniCPMAttention")
|
||||
|
||||
|
||||
def apply_residual_scale(module: torch.nn.Module):
|
||||
if module.__class__.__name__ == "MiniCPMDecoderLayer":
|
||||
scale = module.scale_depth / math.sqrt(module.num_hidden_layers)
|
||||
module.self_attn.o_proj.weight.data *= scale
|
||||
if module.self_attn.o_proj.bias is not None:
|
||||
module.self_attn.o_proj.bias.weight.data *= scale
|
||||
module.mlp.down_proj.weight.data *= scale
|
||||
if module.mlp.down_proj.bias is not None:
|
||||
module.mlp.down_proj.bias.weight.data *= scale
|
||||
|
||||
|
||||
def minicpm_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
|
@ -214,3 +225,52 @@ def minicpm_model_forward_wrapper(origin_forward):
|
|||
)
|
||||
|
||||
return minicpm_model_forward
|
||||
|
||||
|
||||
def minicpm_decoder_layer_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# ipex-llm changes start
|
||||
hidden_states = residual + hidden_states
|
||||
# ipex-llm changes end
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
# ipex-llm changes start
|
||||
hidden_states = residual + hidden_states
|
||||
# ipex-llm changes end
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
|
|
|||
Loading…
Reference in a new issue