remove unnecessary ipex kernel usage (#12649)

parent 9f8b134889
commit 502461d836

2 changed files with 5 additions and 60 deletions
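The Yuan2 path previously registered a custom yuan_mlp_forward that dispatched to ipex fused kernels (xe_linear.mlp_forward_xpu for low-bit weights, torch.ops.torch_ipex.mm_silu and mm_resmul for fp16). This commit deletes that function and its supporting imports, leaving only the attention patch in place; a generic mlp_silu_forward replacement is kept as a commented-out option.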
python/llm/src/ipex_llm/transformers/convert.py

@@ -1984,16 +1984,9 @@ def _optimize_post(model):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.yuan import yuan_attention_forward
-        # from ipex_llm.transformers.models.yuan import yuan_mlp_forward
-        convert_forward(model,
-                        module.YuanAttention,
-                        yuan_attention_forward
-                        )
-        # disable able mlp_forward for quantize_kv on mtl.
-        # convert_forward(model,
-        #                 module.YuanMLP,
-        #                 yuan_mlp_forward
-        #                 )
+        convert_forward(model, module.YuanAttention, yuan_attention_forward)
+        # from ipex_llm.transformers.models.common import mlp_silu_forward
+        # convert_forward(model, module.YuanMLP, mlp_silu_forward)
     elif model.config.model_type == 'bert' and (
         not model.config.is_decoder and
         model.config.position_embedding_type == "absolute"
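For readers outside the codebase: convert_forward swaps the forward method of every matching submodule for an optimized replacement, which is why a single call suffices to patch all YuanAttention layers. A minimal sketch of the mechanism, assuming a bound-method rebind; the actual implementation in convert.py handles more cases, but the effect is the same:

from types import MethodType

def convert_forward_sketch(model, target_class, new_forward):
    # Walk the module tree and rebind forward on every instance of
    # target_class, so e.g. YuanAttention.forward becomes
    # yuan_attention_forward without touching the checkpoint weights.
    for submodule in model.modules():
        if isinstance(submodule, target_class):
            submodule.forward = MethodType(new_forward, submodule)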
python/llm/src/ipex_llm/transformers/models/yuan.py

@@ -20,17 +20,15 @@
 # https://huggingface.co/IEITYuan/Yuan2-2B-hf/blob/7ab7b3c18eb8e5232ce2a3f720d4e6f4b53a2806/README.md#%E5%A3%B0%E6%98%8E%E4%B8%8E%E5%8D%8F%E8%AE%AEterms-and-conditions
 #
 
-import math
 from typing import Optional, Tuple
 
 import torch
 
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.common import scaled_dot_product_attention
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
-    mlp_fusion_check, fp16_fusion_check
+from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import SILU, update_past_key_value
+from ipex_llm.transformers.models.utils import update_past_key_value
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 
 
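The import trims above are fallout from the function deleted in the next hunk: mlp_fusion_check, fp16_fusion_check, and SILU were used only by the fused-kernel branches of yuan_mlp_forward, and import math is dropped as apparently no longer referenced in the module.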
@@ -98,52 +96,6 @@ def yuan_localized_filtering_forward(
     return lf_output
 
 
-def yuan_mlp_forward(
-    self,
-    x: torch.Tensor,
-    residual=None
-) -> torch.Tensor:
-    x_2d = x.view(-1, x.shape[-1])
-    bsz, hidden_size = x_2d.shape
-    qtype = getattr(self.up_proj, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, self.training):
-        import xe_linear
-        if not x_2d.is_contiguous():
-            x_2d = x_2d.contiguous()
-        out = self.down_proj(xe_linear.mlp_forward_xpu(
-            x_2d, self.up_proj.weight.data, self.gate_proj.weight.data,
-            x_2d.shape[0], x_2d.shape[1], self.up_proj.out_len,
-            SILU, qtype
-        ))
-        if residual is not None:
-            return out + residual
-        else:
-            return out
-    elif fp16_fusion_check(self.up_proj, x, self.training) and \
-            hidden_size == 4096 and bsz == 1:
-        hidden_states1 = torch.ops.torch_ipex.mm_silu(x, self.up_proj.weight)
-        hidden_states = torch.ops.torch_ipex.mm_resmul(
-            x, self.gate_proj.weight, hidden_states1
-        )
-        if residual is None:
-            hidden_states = torch.matmul(hidden_states, self.down_proj.weight)
-        else:
-            attn_output = torch.addmm(
-                residual.flatten(0, -2),
-                hidden_states.flatten(0, -2),
-                self.down_proj.weight,
-                beta=1,
-            )
-            hidden_states = attn_output.view(x.shape)
-        return hidden_states
-    else:
-        out = self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
-        if residual is not None:
-            return out + residual
-        else:
-            return out
-
-
 def yuan_attention_forward(
     self,
     hidden_states: torch.Tensor,
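The deleted yuan_mlp_forward chose between three paths: a low-bit fused kernel (xe_linear.mlp_forward_xpu), an fp16 torch_ipex fused path (mm_silu plus mm_resmul, taken only for hidden_size == 4096 at batch size 1), and a plain eager fallback. With the patch removed, Yuan2's MLP simply runs its stock forward; the commented-out lines in convert.py keep mlp_silu_forward available if a generic patch is wanted later. A minimal sketch of the eager computation, mirroring the deleted fallback branch (the standalone form and the function name are mine):

import torch

def yuan_mlp_eager_sketch(self, x: torch.Tensor, residual=None) -> torch.Tensor:
    # Yuan activates up_proj and gates with gate_proj (roles swapped
    # relative to Llama), then projects back down with down_proj.
    out = self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
    # Optionally fold in a residual, as the deleted function did.
    return out + residual if residual is not None else out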