optimize minicpm (#12496)
parent ae9c2154f4
commit a9e3f7f14c

2 changed files with 64 additions and 1 deletion
@@ -1032,8 +1032,9 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.mllama import merge_qkv
         model.apply(merge_qkv)
     elif model.config.model_type == "minicpm":
-        from ipex_llm.transformers.models.minicpm import merge_qkv
+        from ipex_llm.transformers.models.minicpm import merge_qkv, apply_residual_scale
         model.apply(merge_qkv)
+        model.apply(apply_residual_scale)
     elif model.config.model_type == "minicpm3":
         from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
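The new `model.apply(apply_residual_scale)` call relies on `torch.nn.Module.apply`, which visits every submodule recursively, so the weight scaling defined later in this diff runs once per `MiniCPMDecoderLayer` at load time and is a no-op for every other module. A minimal sketch of that dispatch pattern, with a hypothetical `ToyLayer` standing in for the real decoder layer:

import torch

class ToyLayer(torch.nn.Module):
    # hypothetical stand-in for MiniCPMDecoderLayer; only the class name matters here
    def __init__(self):
        super().__init__()
        self.o_proj = torch.nn.Linear(8, 8)

def toy_residual_scale(module: torch.nn.Module):
    # dispatch on the class name, the same way apply_residual_scale does
    if module.__class__.__name__ == "ToyLayer":
        module.o_proj.weight.data *= 0.5

model = torch.nn.Sequential(ToyLayer(), ToyLayer(), torch.nn.Linear(8, 8))
model.apply(toy_residual_scale)   # each submodule is visited exactly once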
@@ -2101,9 +2102,11 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
         from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
+        from ipex_llm.transformers.models.minicpm import minicpm_decoder_layer_forward
         convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
         convert_forward(model, module.MiniCPMMLP, llama_mlp_forward)
         convert_forward(model, module.MiniCPMRMSNorm, llama_rms_norm_forward)
+        convert_forward(model, module.MiniCPMDecoderLayer, minicpm_decoder_layer_forward)
         minicpm_model_forward = minicpm_model_forward_wrapper(module.MiniCPMModel.forward)
         convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
     elif model.config.model_type == "minicpm3":
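`convert_forward` is ipex-llm's own helper and its implementation is not part of this diff; the effect it needs to achieve is to rebind the patched function as the `forward` of every matching submodule, here `MiniCPMDecoderLayer`. A rough, non-authoritative sketch of that pattern (the name `replace_forward` is illustrative, not the real helper):

import types
import torch

def replace_forward(model: torch.nn.Module, target_cls: type, new_forward):
    # illustrative only -- not ipex-llm's actual convert_forward
    # bind new_forward as an instance method on every submodule of target_cls
    for sub in model.modules():
        if isinstance(sub, target_cls):
            sub.forward = types.MethodType(new_forward, sub)
    return model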
@@ -56,6 +56,17 @@ def merge_qkv(module: torch.nn.Module):
     return merge_qkv_base(module, "MiniCPMAttention")
 
 
+def apply_residual_scale(module: torch.nn.Module):
+    if module.__class__.__name__ == "MiniCPMDecoderLayer":
+        scale = module.scale_depth / math.sqrt(module.num_hidden_layers)
+        module.self_attn.o_proj.weight.data *= scale
+        if module.self_attn.o_proj.bias is not None:
+            module.self_attn.o_proj.bias.data *= scale
+        module.mlp.down_proj.weight.data *= scale
+        if module.mlp.down_proj.bias is not None:
+            module.mlp.down_proj.bias.data *= scale
+
+
 def minicpm_attention_forward(
     self,
     hidden_states: torch.Tensor,
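This is the load-time half of the optimization: MiniCPM multiplies each residual branch by `scale_depth / sqrt(num_hidden_layers)`, and because a `Linear` layer is affine, scaling its weight (and bias) by that constant once gives the same output as scaling the branch on every forward pass. A quick self-contained check of that equivalence, with illustrative values for `scale_depth` and `num_hidden_layers`:

import math
import torch

torch.manual_seed(0)
scale = 1.4 / math.sqrt(40)              # illustrative scale_depth / sqrt(num_hidden_layers)
o_proj = torch.nn.Linear(8, 8, bias=True)
x, residual = torch.randn(2, 8), torch.randn(2, 8)

baseline = residual + o_proj(x) * scale  # original scheme: scale applied at runtime

o_proj.weight.data *= scale              # fold the scalar into the weights once,
o_proj.bias.data *= scale                # as apply_residual_scale does
folded = residual + o_proj(x)            # plain residual add, no per-token multiply

assert torch.allclose(baseline, folded, atol=1e-6)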
@@ -214,3 +225,52 @@ def minicpm_model_forward_wrapper(origin_forward):
         )
 
     return minicpm_model_forward
+
+
+def minicpm_decoder_layer_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    **kwargs,
+) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    residual = hidden_states
+    hidden_states = self.input_layernorm(hidden_states)
+
+    # Self Attention
+    hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        **kwargs,
+    )
+
+    # ipex-llm changes start
+    hidden_states = residual + hidden_states
+    # ipex-llm changes end
+
+    # Fully Connected
+    residual = hidden_states
+    hidden_states = self.post_attention_layernorm(hidden_states)
+
+    hidden_states = self.mlp(hidden_states)
+
+    # ipex-llm changes start
+    hidden_states = residual + hidden_states
+    # ipex-llm changes end
+
+    outputs = (hidden_states,)
+
+    if output_attentions:
+        outputs += (self_attn_weights,)
+
+    if use_cache:
+        outputs += (present_key_value,)
+
+    return outputs
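The patched forward can use a plain `residual + hidden_states` precisely because apply_residual_scale already folded the residual scale into o_proj and down_proj. Both halves are wired into _optimize_pre/_optimize_post, which ipex-llm applies while a model is loaded through its transformers API, so nothing changes for users; a typical load looks roughly like the sketch below (checkpoint name is an example):

from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "openbmb/MiniCPM-2B-sft-bf16"   # example MiniCPM checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,       # low-bit weights
                                             trust_remote_code=True)  # MiniCPM ships custom modeling code
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)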