optimize minicpm3 again (#12047)

Authored by Yishuo Wang on 2024-09-10 14:19:57 +08:00, committed by GitHub
parent f0061a9916
commit abc370728c
4 changed files with 18 additions and 9 deletions

View file

@@ -1775,16 +1775,16 @@ def _optimize_post(model, lightweight_bmm=False):
     elif model.config.model_type == "gemma2":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.common import mlp_silu_forward
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
-        from ipex_llm.transformers.models.gemma2 import gemma2_mlp_forward
         from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
         from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2MLP
         convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
         convert_forward(model, Gemma2Attention, gemma2_attention_forward)
         convert_forward(model, Gemma2Model, gemma2_model_forward)
-        convert_forward(model, Gemma2MLP, gemma2_mlp_forward)
+        convert_forward(model, Gemma2MLP, mlp_silu_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -1968,7 +1968,9 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.common import rms_norm_forward
+        from ipex_llm.transformers.models.common import mlp_silu_forward
         convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
+        convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
         from ipex_llm.transformers.models.minicpm3 import minicpm3_attention_forward
         convert_forward(model, module.MiniCPMAttention, minicpm3_attention_forward)
     elif model.config.model_type == "minicpmv":
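Note on what these two hunks do: convert_forward swaps in an optimized forward for each submodule of the given class, so after this change both Gemma2MLP and MiniCPMMLP share the new mlp_silu_forward helper instead of per-model MLP forwards. Below is a minimal sketch of how such a rebinding helper could work, assuming it simply walks the module tree; convert_forward_sketch is an illustrative name, not the repo's actual implementation.

    import types
    import torch

    def convert_forward_sketch(model: torch.nn.Module, target_cls: type, new_forward) -> None:
        # Illustrative sketch only: rebind `forward` on every instance of
        # `target_cls`, so module(x) dispatches to `new_forward` afterwards.
        for module in model.modules():
            if isinstance(module, target_cls):
                module.forward = types.MethodType(new_forward, module)

Under that assumption, the MiniCPM3 branch now rewrites three pieces per model: RMSNorm, the MLP, and attention.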

View file

@@ -69,6 +69,16 @@ def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
     return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
 
 
+def mlp_silu_forward(self, x: torch.Tensor):
+    from ipex_llm.transformers.models.utils import SILU
+    return fuse_mlp_base(self, SILU, x)
+
+
+def mlp_gelu_forward(self, x: torch.Tensor):
+    from ipex_llm.transformers.models.utils import GELU
+    return fuse_mlp_base(self, GELU, x)
+
+
 def attention_softmax(attn_weights: torch.Tensor, training: bool):
     if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
         import xe_addons
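The context line above shows fuse_mlp_base's eager fallback, down_proj(act_fn(gate_proj(x)) * up_proj(x)), i.e. a standard gated MLP; the new mlp_silu_forward and mlp_gelu_forward are thin wrappers that only choose the activation id handed to that helper (presumably so a fused xe_addons kernel can be selected on XPU). A plain-PyTorch reference for the same math, with no ipex_llm or xe_addons dependency; all names below are illustrative, not from the repo.

    import torch
    import torch.nn.functional as F

    def gated_mlp_reference(gate_proj, up_proj, down_proj, x, act="silu"):
        # Same math as the eager fallback above:
        # down_proj(act_fn(gate_proj(x)) * up_proj(x)).
        act_fn = F.silu if act == "silu" else (lambda t: F.gelu(t, approximate="tanh"))
        return down_proj(act_fn(gate_proj(x)) * up_proj(x))

    # Quick shape check on CPU (sizes are arbitrary, for illustration only).
    hidden, inter = 8, 16
    gate = torch.nn.Linear(hidden, inter, bias=False)
    up = torch.nn.Linear(hidden, inter, bias=False)
    down = torch.nn.Linear(inter, hidden, bias=False)
    assert gated_mlp_reference(gate, up, down, torch.randn(2, hidden)).shape == (2, hidden)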

View file

@@ -34,7 +34,7 @@
 
 import torch
 from typing import Optional, Tuple
-from ipex_llm.transformers.models.common import merge_qkv_base, fuse_mlp_base
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import GELU
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
 from transformers.cache_utils import Cache
@@ -184,7 +184,3 @@ def gemma2_attention_forward(
         attn_weights = None
 
     return attn_output, attn_weights, past_key_value
-
-
-def gemma2_mlp_forward(self, x: torch.Tensor):
-    return fuse_mlp_base(self, GELU, x)

View file

@@ -92,8 +92,9 @@ def minicpm3_attention_forward(
             xe_addons.rotary_half_inplaced(inv_freq, position_ids,
                                            query_states[:, :, :, self.qk_nope_head_dim:],
                                            key_states[:, :, :, self.qk_nope_head_dim:])
-            query_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
-            key_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
+            if self.rotary_emb.scaling_factor != 1.0:
+                query_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
+                key_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
         else:
             invalidInputError(f"unknown rope method: {self.rotary_emb.__class__.__name__}")
     else:
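This hunk is one of the MiniCPM3 optimizations named in the title: the post-RoPE scaling multiplications on the rotary slice of q/k are now skipped when scaling_factor is 1.0, avoiding two in-place element-wise passes that would otherwise be no-ops. A standalone sketch of the guarded update; the function name and signature are invented for illustration.

    import torch

    def scale_rotary_slice(query_states: torch.Tensor,
                           key_states: torch.Tensor,
                           qk_nope_head_dim: int,
                           scaling_factor: float) -> None:
        # Only the rotary part of q/k (columns past qk_nope_head_dim) is scaled,
        # and the in-place multiplies are skipped entirely when the factor is 1.0.
        if scaling_factor != 1.0:
            query_states[:, :, :, qk_nope_head_dim:] *= scaling_factor
            key_states[:, :, :, qk_nope_head_dim:] *= scaling_factor

    # Example: with scaling_factor == 1.0 the tensors are left untouched.
    q = torch.randn(1, 2, 4, 96)
    q_ref = q.clone()
    scale_rotary_slice(q, q.clone(), qk_nope_head_dim=64, scaling_factor=1.0)
    assert torch.equal(q, q_ref)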