add mlp for gemma2 (#11678)
parent 1da1f1dd0e
commit c02003925b
3 changed files with 27 additions and 2 deletions
@@ -1513,11 +1513,13 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
+        from ipex_llm.transformers.models.gemma2 import gemma2_mlp_forward
         from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
-        from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2MLP
         convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
         convert_forward(model, Gemma2Attention, gemma2_attention_forward)
         convert_forward(model, Gemma2Model, gemma2_model_forward)
+        convert_forward(model, Gemma2MLP, gemma2_mlp_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
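For context, convert_forward rebinds the forward method of every module instance of the given class, so the two added lines route Gemma2's MLP through gemma2_mlp_forward in the same way the attention and model forwards were already routed. A minimal sketch of how such a helper can work (an assumption for illustration, not the exact ipex-llm implementation):

import torch

def convert_forward_sketch(m: torch.nn.Module, target_cls: type, new_forward) -> None:
    # Bind new_forward as the forward method of every instance of target_cls
    # in the module tree, leaving weights and submodules untouched.
    if m.__class__ == target_cls:
        m.forward = new_forward.__get__(m, m.__class__)
    for child in m.children():
        convert_forward_sketch(child, target_cls, new_forward)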
@@ -41,3 +41,21 @@ def merge_qkv_base(module: torch.nn.Module, attention_class):
     ])
     module.qkv_proj = qkv_proj
     del module.q_proj, module.k_proj, module.v_proj
+
+
+def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
+    from ipex_llm.transformers.models.utils import mlp_fusion_check
+    x_2d = x.view(-1, x.size(-1))
+    qtype = getattr(module.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, module.training):
+        import xe_linear
+        x_2d = x_2d.contiguous()
+        return module.down_proj(
+            xe_linear.mlp_forward_xpu(
+                x_2d, module.gate_proj.weight.data, module.up_proj.weight.data,
+                x_2d.size(0), x_2d.size(1), module.gate_proj.out_len,
+                act, qtype
+            )
+        )
+    else:
+        return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
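The else branch above is the standard gated-MLP expression; the fused branch computes the same result with a single xe_linear kernel call when mlp_fusion_check allows it. A self-contained sketch of that reference expression, using a hypothetical GatedMLP module as a stand-in for Gemma2MLP:

import torch
import torch.nn as nn

class GatedMLP(nn.Module):
    # Hypothetical stand-in for a gated MLP such as Gemma2MLP:
    # down_proj(act_fn(gate_proj(x)) * up_proj(x)) is the expression that
    # fuse_mlp_base reproduces with one fused XPU kernel when possible.
    def __init__(self, hidden: int, intermediate: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)
        self.act_fn = nn.GELU(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Eager path, matching the else branch of fuse_mlp_base.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

mlp = GatedMLP(hidden=16, intermediate=64)
out = mlp(torch.randn(2, 5, 16))   # (batch, seq, hidden) -> same shape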
@@ -34,7 +34,8 @@
 import torch
 
 from typing import Optional, Tuple
-from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import merge_qkv_base, fuse_mlp_base
+from ipex_llm.transformers.models.utils import GELU
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
 from transformers.cache_utils import Cache
 from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2Attention
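GELU imported here is passed below as the act argument of fuse_mlp_base. The assumption is that it is a small integer selector telling the fused XPU kernel which activation to apply, roughly along these lines (illustrative values, not the real definitions):

# Assumed activation selectors for xe_linear.mlp_forward_xpu; the actual values
# live in ipex_llm.transformers.models.utils and may differ.
SILU = 0   # SwiGLU-style MLPs (e.g. Llama-family)
GELU = 1   # GeGLU-style MLPs (e.g. Gemma-family)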
@@ -177,3 +178,7 @@ def gemma2_attention_forward(
     attn_weights = None
 
     return attn_output, attn_weights, past_key_value
+
+
+def gemma2_mlp_forward(self, x: torch.Tensor):
+    return fuse_mlp_base(self, GELU, x)
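With the conversion in place, every Gemma2MLP instance in an optimized model calls gemma2_mlp_forward and, when the fusion check passes, the fused XPU kernel. A usage sketch (checkpoint name and loading flags are illustrative; exact arguments depend on the ipex-llm version):

from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

# Loading through ipex_llm triggers the optimization pass above, which swaps in
# gemma2_mlp_forward (and the attention/model forwards) for gemma2 checkpoints.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it",
                                             load_in_4bit=True).to("xpu")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")

inputs = tokenizer("What is Gemma 2?", return_tensors="pt").to("xpu")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))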