remove unnecessary ipex kernel usage (#12649)

Author: Yishuo Wang, 2025-01-03 16:45:24 +08:00 (committed by GitHub)
Commit: 502461d836, parent: 9f8b134889
2 changed files with 5 additions and 60 deletions


@@ -1984,16 +1984,9 @@ def _optimize_post(model):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.yuan import yuan_attention_forward
-        # from ipex_llm.transformers.models.yuan import yuan_mlp_forward
-        convert_forward(model,
-                        module.YuanAttention,
-                        yuan_attention_forward
-                        )
-        # disable able mlp_forward for quantize_kv on mtl.
-        # convert_forward(model,
-        #                 module.YuanMLP,
-        #                 yuan_mlp_forward
-        #                 )
+        convert_forward(model, module.YuanAttention, yuan_attention_forward)
+        # from ipex_llm.transformers.models.common import mlp_silu_forward
+        # convert_forward(model, module.YuanMLP, mlp_silu_forward)
     elif model.config.model_type == 'bert' and (
             not model.config.is_decoder and
             model.config.position_embedding_type == "absolute"
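
After this change the Yuan path only patches `YuanAttention`, with a single `convert_forward` call. As a rough illustration of the pattern (not the actual `ipex_llm.transformers.convert.convert_forward`, which is not part of this diff), a convert_forward-style helper can be sketched as rebinding `forward` on every submodule of the target class:

```python
# Hypothetical sketch of a convert_forward-style helper; the real
# ipex_llm implementation is not shown in this diff.
import types
import torch.nn as nn


def convert_forward_sketch(model: nn.Module, target_class, new_forward):
    for sub_module in model.modules():
        if isinstance(sub_module, target_class):
            # Rebind the replacement function as this instance's forward method.
            sub_module.forward = types.MethodType(new_forward, sub_module)
```

Under that reading, `convert_forward(model, module.YuanAttention, yuan_attention_forward)` swaps in the IPEX-LLM attention forward for every `YuanAttention` layer, while the MLP keeps its stock forward; the commented-out lines show how an `mlp_silu_forward` override could be re-enabled later.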


@@ -20,17 +20,15 @@
 # https://huggingface.co/IEITYuan/Yuan2-2B-hf/blob/7ab7b3c18eb8e5232ce2a3f720d4e6f4b53a2806/README.md#%E5%A3%B0%E6%98%8E%E4%B8%8E%E5%8D%8F%E8%AE%AEterms-and-conditions
 #
 import math
 from typing import Optional, Tuple
 import torch
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.common import scaled_dot_product_attention
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
-    mlp_fusion_check, fp16_fusion_check
+from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import SILU, update_past_key_value
+from ipex_llm.transformers.models.utils import update_past_key_value
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
@@ -98,52 +96,6 @@ def yuan_localized_filtering_forward(
     return lf_output


-def yuan_mlp_forward(
-    self,
-    x: torch.Tensor,
-    residual=None
-) -> torch.Tensor:
-    x_2d = x.view(-1, x.shape[-1])
-    bsz, hidden_size = x_2d.shape
-    qtype = getattr(self.up_proj, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, self.training):
-        import xe_linear
-        if not x_2d.is_contiguous():
-            x_2d = x_2d.contiguous()
-        out = self.down_proj(xe_linear.mlp_forward_xpu(
-            x_2d, self.up_proj.weight.data, self.gate_proj.weight.data,
-            x_2d.shape[0], x_2d.shape[1], self.up_proj.out_len,
-            SILU, qtype
-        ))
-        if residual is not None:
-            return out + residual
-        else:
-            return out
-    elif fp16_fusion_check(self.up_proj, x, self.training) and \
-            hidden_size == 4096 and bsz == 1:
-        hidden_states1 = torch.ops.torch_ipex.mm_silu(x, self.up_proj.weight)
-        hidden_states = torch.ops.torch_ipex.mm_resmul(
-            x, self.gate_proj.weight, hidden_states1
-        )
-        if residual is None:
-            hidden_states = torch.matmul(hidden_states, self.down_proj.weight)
-        else:
-            attn_output = torch.addmm(
-                residual.flatten(0, -2),
-                hidden_states.flatten(0, -2),
-                self.down_proj.weight,
-                beta=1,
-            )
-            hidden_states = attn_output.view(x.shape)
-        return hidden_states
-    else:
-        out = self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
-        if residual is not None:
-            return out + residual
-        else:
-            return out
-
-
 def yuan_attention_forward(
     self,
     hidden_states: torch.Tensor,
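
The deleted `yuan_mlp_forward` dispatched to ipex-specific fused kernels (`xe_linear.mlp_forward_xpu` when `mlp_fusion_check` passed, `torch.ops.torch_ipex.mm_silu` / `mm_resmul` on the fp16 path) and otherwise fell back to a plain SiLU-gated MLP. With the override removed, only that plain path remains. A minimal sketch of it, mirroring the final else branch of the removed function (module sizes are placeholder assumptions; layer names and the up_proj/gate_proj roles follow the removed code):

```python
import torch
import torch.nn as nn


class YuanMLPSketch(nn.Module):
    """Un-fused SiLU-gated MLP, as in the last else branch of the removed
    yuan_mlp_forward. Hidden/intermediate sizes here are placeholders."""

    def __init__(self, hidden_size: int = 2048, intermediate_size: int = 8192):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor, residual: torch.Tensor = None) -> torch.Tensor:
        # Yuan convention: the activation is applied to up_proj and the result
        # is gated by gate_proj, exactly as in the removed fallback branch.
        out = self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
        return out + residual if residual is not None else out
```

After this commit the MLP always runs through the regular (or quantized) linear layers rather than the fused XPU kernels, which the commit title deems unnecessary.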