diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 169b7102..02a8fe6c 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1984,16 +1984,9 @@ def _optimize_post(model):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.yuan import yuan_attention_forward
-        # from ipex_llm.transformers.models.yuan import yuan_mlp_forward
-        convert_forward(model,
-                        module.YuanAttention,
-                        yuan_attention_forward
-                        )
-        # disable able mlp_forward for quantize_kv on mtl.
-        # convert_forward(model,
-        #                 module.YuanMLP,
-        #                 yuan_mlp_forward
-        #                 )
+        convert_forward(model, module.YuanAttention, yuan_attention_forward)
+        # from ipex_llm.transformers.models.common import mlp_silu_forward
+        # convert_forward(model, module.YuanMLP, mlp_silu_forward)
     elif model.config.model_type == 'bert' and (
         not model.config.is_decoder and
             model.config.position_embedding_type == "absolute"
diff --git a/python/llm/src/ipex_llm/transformers/models/yuan.py b/python/llm/src/ipex_llm/transformers/models/yuan.py
index afa99da1..e6d3ddbe 100644
--- a/python/llm/src/ipex_llm/transformers/models/yuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/yuan.py
@@ -20,17 +20,15 @@
 # https://huggingface.co/IEITYuan/Yuan2-2B-hf/blob/7ab7b3c18eb8e5232ce2a3f720d4e6f4b53a2806/README.md#%E5%A3%B0%E6%98%8E%E4%B8%8E%E5%8D%8F%E8%AE%AEterms-and-conditions
 #
 
-import math
 from typing import Optional, Tuple
 
 import torch
 
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.common import scaled_dot_product_attention
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
-    mlp_fusion_check, fp16_fusion_check
+from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import SILU, update_past_key_value
+from ipex_llm.transformers.models.utils import update_past_key_value
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 
 
@@ -98,52 +96,6 @@ def yuan_localized_filtering_forward(
     return lf_output
 
 
-def yuan_mlp_forward(
-    self,
-    x: torch.Tensor,
-    residual=None
-) -> torch.Tensor:
-    x_2d = x.view(-1, x.shape[-1])
-    bsz, hidden_size = x_2d.shape
-    qtype = getattr(self.up_proj, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, self.training):
-        import xe_linear
-        if not x_2d.is_contiguous():
-            x_2d = x_2d.contiguous()
-        out = self.down_proj(xe_linear.mlp_forward_xpu(
-            x_2d, self.up_proj.weight.data, self.gate_proj.weight.data,
-            x_2d.shape[0], x_2d.shape[1], self.up_proj.out_len,
-            SILU, qtype
-        ))
-        if residual is not None:
-            return out + residual
-        else:
-            return out
-    elif fp16_fusion_check(self.up_proj, x, self.training) and \
-            hidden_size == 4096 and bsz == 1:
-        hidden_states1 = torch.ops.torch_ipex.mm_silu(x, self.up_proj.weight)
-        hidden_states = torch.ops.torch_ipex.mm_resmul(
-            x, self.gate_proj.weight, hidden_states1
-        )
-        if residual is None:
-            hidden_states = torch.matmul(hidden_states, self.down_proj.weight)
-        else:
-            attn_output = torch.addmm(
-                residual.flatten(0, -2),
-                hidden_states.flatten(0, -2),
-                self.down_proj.weight,
-                beta=1,
-            )
-            hidden_states = attn_output.view(x.shape)
-        return hidden_states
-    else:
-        out = self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
-        if residual is not None:
-            return out + residual
-        else:
-            return out
-
-
 def yuan_attention_forward(
     self,
     hidden_states: torch.Tensor,
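
Note: convert_forward, as called in the first hunk, replaces the forward method of every instance of the given module class in the loaded model with the supplied function. Below is a minimal, self-contained sketch of that monkey-patching pattern, for illustration only; convert_forward_sketch is a hypothetical name and this is not the actual IPEX-LLM implementation.

# Sketch only, assuming standard PyTorch; not ipex_llm.transformers.convert.convert_forward.
import types

import torch.nn as nn


def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
    # Walk all submodules and bind new_forward as the instance's forward
    # whenever the submodule is an instance of target_cls.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)
    return model

A call would mirror the diff above, e.g. convert_forward_sketch(model, module.YuanAttention, yuan_attention_forward).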