optimize minicpm3 again (#12047)
parent f0061a9916
commit abc370728c
4 changed files with 18 additions and 9 deletions
@@ -1775,16 +1775,16 @@ def _optimize_post(model, lightweight_bmm=False):
     elif model.config.model_type == "gemma2":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.common import mlp_silu_forward
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
-        from ipex_llm.transformers.models.gemma2 import gemma2_mlp_forward
         from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
         from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2MLP
         convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
         convert_forward(model, Gemma2Attention, gemma2_attention_forward)
         convert_forward(model, Gemma2Model, gemma2_model_forward)
-        convert_forward(model, Gemma2MLP, gemma2_mlp_forward)
+        convert_forward(model, Gemma2MLP, mlp_silu_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -1968,7 +1968,9 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.common import rms_norm_forward
+        from ipex_llm.transformers.models.common import mlp_silu_forward
         convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
+        convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
         from ipex_llm.transformers.models.minicpm3 import minicpm3_attention_forward
         convert_forward(model, module.MiniCPMAttention, minicpm3_attention_forward)
     elif model.config.model_type == "minicpmv":
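
The two hunks above change how _optimize_post wires up the MLP optimizations: the Gemma2 path stops importing the model-specific gemma2_mlp_forward, and both Gemma2MLP and MiniCPM3's MiniCPMMLP are now bound to the shared mlp_silu_forward helper through convert_forward. As a rough illustration of that patching pattern, the sketch below rebinds forward on every submodule of a target class; it is an assumption about how such a patch works in general, not the actual ipex_llm convert_forward implementation.

import types

import torch


def convert_forward_sketch(model: torch.nn.Module, target_cls: type, new_forward):
    # Illustrative stand-in for convert_forward: rebind `forward` on every
    # submodule that is an instance of target_cls, so e.g. every MiniCPMMLP
    # would run the shared mlp_silu_forward when called.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)
    return model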
@@ -69,6 +69,16 @@ def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
         return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
 
 
+def mlp_silu_forward(self, x: torch.Tensor):
+    from ipex_llm.transformers.models.utils import SILU
+    return fuse_mlp_base(self, SILU, x)
+
+
+def mlp_gelu_forward(self, x: torch.Tensor):
+    from ipex_llm.transformers.models.utils import GELU
+    return fuse_mlp_base(self, GELU, x)
+
+
 def attention_softmax(attn_weights: torch.Tensor, training: bool):
     if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
         import xe_addons
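
The new helpers above are thin wrappers that call fuse_mlp_base with a fixed activation constant (SILU or GELU). When the fused path is not taken, fuse_mlp_base falls back to the plain gated-MLP expression visible in the context line of this hunk. The following self-contained sketch reproduces that fallback computation on a toy module; ToyGatedMLP and its sizes are made up for illustration and are not part of the library.

import torch
import torch.nn as nn


class ToyGatedMLP(nn.Module):
    # Toy module with the gate_proj / up_proj / down_proj / act_fn layout
    # that fuse_mlp_base expects; hidden sizes are arbitrary.
    def __init__(self, hidden: int = 64, intermediate: int = 128):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)
        self.act_fn = nn.SiLU()  # the SiLU case; mlp_gelu_forward corresponds to GELU

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same expression as fuse_mlp_base's fallback branch above.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


x = torch.randn(2, 8, 64)
print(ToyGatedMLP()(x).shape)  # torch.Size([2, 8, 64])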
@@ -34,7 +34,7 @@
 import torch
 
 from typing import Optional, Tuple
-from ipex_llm.transformers.models.common import merge_qkv_base, fuse_mlp_base
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import GELU
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
 from transformers.cache_utils import Cache
@@ -184,7 +184,3 @@ def gemma2_attention_forward(
     attn_weights = None
 
     return attn_output, attn_weights, past_key_value
-
-
-def gemma2_mlp_forward(self, x: torch.Tensor):
-    return fuse_mlp_base(self, GELU, x)
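
Note that the removed gemma2_mlp_forward was a one-line wrapper, return fuse_mlp_base(self, GELU, x), with the same body as the mlp_gelu_forward helper added to ipex_llm.transformers.models.common above, so the model-specific copy becomes redundant once the shared helpers exist.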
@@ -92,8 +92,9 @@ def minicpm3_attention_forward(
             xe_addons.rotary_half_inplaced(inv_freq, position_ids,
                                            query_states[:, :, :, self.qk_nope_head_dim:],
                                            key_states[:, :, :, self.qk_nope_head_dim:])
-            query_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
-            key_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
+            if self.rotary_emb.scaling_factor != 1.0:
+                query_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
+                key_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
         else:
             invalidInputError(f"unknown rope method: {self.rotary_emb.__class__.__name__}")
     else:
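
The final hunk guards the post-RoPE rescaling so the two in-place multiplications are skipped when scaling_factor equals 1.0, saving two elementwise operations in the unscaled case. A small standalone sketch of the same pattern follows; the tensor shapes and scaling value are illustrative only.

import torch


def rescale_rope_slices(query_states: torch.Tensor, key_states: torch.Tensor,
                        qk_nope_head_dim: int, scaling_factor: float) -> None:
    # Only the rotary part of the head dimension (beyond qk_nope_head_dim) is
    # scaled, and only when scaling actually changes the values.
    if scaling_factor != 1.0:
        query_states[:, :, :, qk_nope_head_dim:] *= scaling_factor
        key_states[:, :, :, qk_nope_head_dim:] *= scaling_factor


q = torch.randn(1, 4, 16, 96)  # [batch, heads, seq, head_dim], made-up sizes
k = torch.randn(1, 4, 16, 96)
rescale_rope_slices(q, k, qk_nope_head_dim=64, scaling_factor=1.0)  # no-op, multiplies skipped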