From abc370728c0a35787ecb08edd3eadb3a8ca54d7a Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 10 Sep 2024 14:19:57 +0800
Subject: [PATCH] optimize minicpm3 again (#12047)

---
 python/llm/src/ipex_llm/transformers/convert.py       |  6 ++++--
 python/llm/src/ipex_llm/transformers/models/common.py | 10 ++++++++++
 python/llm/src/ipex_llm/transformers/models/gemma2.py |  6 +-----
 .../llm/src/ipex_llm/transformers/models/minicpm3.py  |  5 +++--
 4 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 90eb2eae..d3fd44c3 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1775,16 +1775,16 @@ def _optimize_post(model, lightweight_bmm=False):
     elif model.config.model_type == "gemma2":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.common import mlp_silu_forward
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
-        from ipex_llm.transformers.models.gemma2 import gemma2_mlp_forward
         from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
         from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2MLP
         convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
         convert_forward(model, Gemma2Attention, gemma2_attention_forward)
         convert_forward(model, Gemma2Model, gemma2_model_forward)
-        convert_forward(model, Gemma2MLP, gemma2_mlp_forward)
+        convert_forward(model, Gemma2MLP, mlp_silu_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -1968,7 +1968,9 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.common import rms_norm_forward
+        from ipex_llm.transformers.models.common import mlp_silu_forward
         convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
+        convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
         from ipex_llm.transformers.models.minicpm3 import minicpm3_attention_forward
         convert_forward(model, module.MiniCPMAttention, minicpm3_attention_forward)
     elif model.config.model_type == "minicpmv":
diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index deec1551..13ad662e 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -69,6 +69,16 @@ def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
         return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
 
 
+def mlp_silu_forward(self, x: torch.Tensor):
+    from ipex_llm.transformers.models.utils import SILU
+    return fuse_mlp_base(self, SILU, x)
+
+
+def mlp_gelu_forward(self, x: torch.Tensor):
+    from ipex_llm.transformers.models.utils import GELU
+    return fuse_mlp_base(self, GELU, x)
+
+
 def attention_softmax(attn_weights: torch.Tensor, training: bool):
     if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
         import xe_addons
diff --git a/python/llm/src/ipex_llm/transformers/models/gemma2.py b/python/llm/src/ipex_llm/transformers/models/gemma2.py
index 07f8314a..0d4584f2 100644
--- a/python/llm/src/ipex_llm/transformers/models/gemma2.py
+++ b/python/llm/src/ipex_llm/transformers/models/gemma2.py
@@ -34,7 +34,7 @@
 import torch
 
 from typing import Optional, Tuple
-from ipex_llm.transformers.models.common import merge_qkv_base, fuse_mlp_base
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import GELU
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
 from transformers.cache_utils import Cache
@@ -184,7 +184,3 @@ def gemma2_attention_forward(
     attn_weights = None
 
     return attn_output, attn_weights, past_key_value
-
-
-def gemma2_mlp_forward(self, x: torch.Tensor):
-    return fuse_mlp_base(self, GELU, x)
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm3.py b/python/llm/src/ipex_llm/transformers/models/minicpm3.py
index c30c07e4..a47b7647 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpm3.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpm3.py
@@ -92,8 +92,9 @@ def minicpm3_attention_forward(
             xe_addons.rotary_half_inplaced(inv_freq, position_ids,
                                            query_states[:, :, :, self.qk_nope_head_dim:],
                                            key_states[:, :, :, self.qk_nope_head_dim:])
-            query_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
-            key_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
+            if self.rotary_emb.scaling_factor != 1.0:
+                query_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
+                key_states[:, :, :, self.qk_nope_head_dim:] *= self.rotary_emb.scaling_factor
         else:
             invalidInputError(f"unknown rope method: {self.rotary_emb.__class__.__name__}")
     else:
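
Note on the conversion pattern this patch relies on: convert_forward(model, SomeClass, new_forward) replaces the forward of every matching submodule, which is how the shared mlp_silu_forward from common.py takes over for the removed per-model gemma2_mlp_forward and is attached to MiniCPMMLP. The snippet below is a minimal, self-contained sketch of that per-instance rebinding pattern, under the assumption that convert_forward simply rebinds the method on each matching module instance; the names patch_forward, DummyMLP, and silu_mlp_forward are hypothetical stand-ins for illustration and are not part of ipex_llm, and the MLP here uses the eager gate/up/down path rather than the fused XPU kernel.

import types

import torch


def patch_forward(model: torch.nn.Module, target_cls: type, new_forward):
    # Rebind `forward` on every submodule that is an instance of `target_cls`.
    # Minimal sketch of the conversion pattern; the real convert_forward in
    # ipex_llm/transformers/convert.py may differ in details.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)


class DummyMLP(torch.nn.Module):
    # Hypothetical stand-in with the gate/up/down layout that fuse_mlp_base expects.
    def __init__(self, hidden: int = 16, inner: int = 32):
        super().__init__()
        self.gate_proj = torch.nn.Linear(hidden, inner, bias=False)
        self.up_proj = torch.nn.Linear(hidden, inner, bias=False)
        self.down_proj = torch.nn.Linear(inner, hidden, bias=False)
        self.act_fn = torch.nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


def silu_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Same shape as mlp_silu_forward in common.py, but written against the
    # eager fallback path so the sketch runs on CPU without a fused kernel.
    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


if __name__ == "__main__":
    model = torch.nn.Sequential(DummyMLP(), DummyMLP())
    patch_forward(model, DummyMLP, silu_mlp_forward)
    print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 16])

Rebinding per instance (rather than subclassing or rebuilding the model) keeps the original weights and module state untouched, which matches how this patch swaps MLP and attention forwards without modifying checkpoints.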