optimize minicpm3 again (#12047)
This commit is contained in:
parent f0061a9916
commit abc370728c
4 changed files with 18 additions and 9 deletions
ipex_llm/transformers/convert.py
@@ -1775,16 +1775,16 @@ def _optimize_post(model, lightweight_bmm=False):
     elif model.config.model_type == "gemma2":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.common import mlp_gelu_forward
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
-        from ipex_llm.transformers.models.gemma2 import gemma2_mlp_forward
         from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
         from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2MLP
         convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
         convert_forward(model, Gemma2Attention, gemma2_attention_forward)
         convert_forward(model, Gemma2Model, gemma2_model_forward)
-        convert_forward(model, Gemma2MLP, gemma2_mlp_forward)
+        convert_forward(model, Gemma2MLP, mlp_gelu_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
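Note on the hunk above: convert_forward rebinds the forward method of every matching submodule so the shared helper runs instead of the module-local one. A simplified sketch of that pattern (hypothetical helper, not the actual ipex-llm implementation):

import types
import torch

def convert_forward_sketch(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    # Rebind `forward` on every submodule that is an instance of `target_cls`,
    # so subsequent calls dispatch to the optimized implementation.
    for sub_module in model.modules():
        if isinstance(sub_module, target_cls):
            sub_module.forward = types.MethodType(new_forward, sub_module)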
@@ -1968,7 +1968,9 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.common import rms_norm_forward
+        from ipex_llm.transformers.models.common import mlp_silu_forward
         convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
+        convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
         from ipex_llm.transformers.models.minicpm3 import minicpm3_attention_forward
         convert_forward(model, module.MiniCPMAttention, minicpm3_attention_forward)
     elif model.config.model_type == "minicpmv":
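These conversions run inside ipex-llm's from_pretrained path when model optimization is enabled. A rough usage sketch; the checkpoint id, dtype and device handling are illustrative assumptions, not part of this commit:

# Load MiniCPM3 through ipex-llm so the RMSNorm/MLP/attention forwards
# registered above are swapped in by _optimize_post.
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "openbmb/MiniCPM3-4B",      # example checkpoint
    load_in_4bit=True,          # low-bit weight conversion
    optimize_model=True,        # run the _optimize_post conversions
    trust_remote_code=True,     # MiniCPM3 ships its own modeling code
).half().to("xpu")
tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM3-4B", trust_remote_code=True)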
ipex_llm/transformers/models/common.py
@@ -69,6 +69,16 @@ def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
         return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))


+def mlp_silu_forward(self, x: torch.Tensor):
+    from ipex_llm.transformers.models.utils import SILU
+    return fuse_mlp_base(self, SILU, x)
+
+
+def mlp_gelu_forward(self, x: torch.Tensor):
+    from ipex_llm.transformers.models.utils import GELU
+    return fuse_mlp_base(self, GELU, x)
+
+
 def attention_softmax(attn_weights: torch.Tensor, training: bool):
     if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
         import xe_addons
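Both new helpers funnel into fuse_mlp_base, which computes the standard gated MLP. A plain-PyTorch reference for what they compute (layer sizes here are illustrative, not taken from any model; the fused XPU kernel only changes how, not what, is computed):

import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 512, 1536
gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

def reference_silu_mlp(x: torch.Tensor) -> torch.Tensor:
    # What fuse_mlp_base(self, SILU, x) evaluates to (MiniCPM3 uses SiLU).
    return down_proj(F.silu(gate_proj(x)) * up_proj(x))

def reference_gelu_mlp(x: torch.Tensor) -> torch.Tensor:
    # What fuse_mlp_base(self, GELU, x) evaluates to (Gemma2 uses tanh-approximated GELU).
    return down_proj(F.gelu(gate_proj(x), approximate="tanh") * up_proj(x))

y = reference_silu_mlp(torch.randn(2, 8, hidden_size))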
ipex_llm/transformers/models/gemma2.py
@@ -34,7 +34,7 @@
 import torch

 from typing import Optional, Tuple
-from ipex_llm.transformers.models.common import merge_qkv_base, fuse_mlp_base
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import GELU
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
 from transformers.cache_utils import Cache
@@ -184,7 +184,3 @@ def gemma2_attention_forward(
     attn_weights = None

     return attn_output, attn_weights, past_key_value
-
-
-def gemma2_mlp_forward(self, x: torch.Tensor):
-    return fuse_mlp_base(self, GELU, x)
ipex_llm/transformers/models/minicpm3.py
@@ -92,8 +92,9 @@ def minicpm3_attention_forward(
             xe_addons.rotary_half_inplaced(inv_freq, position_ids,
                                            query_states[:, :, :, self.qk_nope_head_dim:],
                                            key_states[:, :, :, self.qk_nope_head_dim:])
-            query_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
-            key_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
+            if self.rotary_emb.scaling_factor != 1.0:
+                query_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
+                key_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
         else:
             invalidInputError(False, f"unknown rope method: {self.rotary_emb.__class__.__name__}")
     else:
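The hunk above only guards the existing scaling: when scaling_factor is exactly 1.0 the two in-place multiplies cannot change the tensors, so skipping them saves two element-wise operations per attention layer. A standalone sketch of the same idea, with hypothetical names and shapes:

import torch

def scale_rotary_half(query: torch.Tensor, key: torch.Tensor,
                      qk_nope_head_dim: int, scaling_factor: float) -> None:
    # Scale only the rotary part of Q/K, and skip the work when it would be a no-op.
    if scaling_factor != 1.0:
        query[:, :, :, qk_nope_head_dim:] *= scaling_factor
        key[:, :, :, qk_nope_head_dim:] *= scaling_factor

# Dummy shapes: (batch, heads, seq, head_dim) with the rotary half starting at index 64.
q = torch.randn(1, 2, 4, 96)
k = torch.randn(1, 2, 4, 96)
scale_rotary_half(q, k, qk_nope_head_dim=64, scaling_factor=1.0)  # skipped entirely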