diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 3e332878..5b7de52c 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -116,18 +116,18 @@ def padding_qkv_hd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 
 def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
     from ipex_llm.transformers.models.utils import mlp_fusion_check
-    x_2d = x.view(-1, x.size(-1))
     qtype = getattr(module.gate_proj, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, module.training):
+    if mlp_fusion_check(x, qtype, module.training):
         import xe_linear
-        x_2d = x_2d.contiguous()
-        return module.down_proj(
+        x_2d = x.contiguous().view(-1, x.size(-1))
+        output = module.down_proj(
             xe_linear.mlp_forward_xpu(
                 x_2d, module.gate_proj.weight.data, module.up_proj.weight.data,
                 x_2d.size(0), x_2d.size(1), module.gate_proj.out_len,
                 act, qtype
             )
         )
+        return output.view(x.shape)
     else:
         return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
 
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 463ff4bc..351ce689 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -336,9 +336,7 @@ def use_sdp_non_causal(head_dim, device, dtype):
 
 
 def mlp_fusion_check(x, qtype, training):
-    invalidInputError(x.dim() == 2,
-                      "Here input x's dim should be 2.")
-    if x.shape[0] != 1:
+    if x.numel() // x.size(-1) != 1:
         return False
     if x.device.type != 'xpu':
         return False
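
A minimal standalone sketch (not part of the patch; tensor shapes are illustrative assumptions) of what the two hunks accomplish together: the new rejection test in mlp_fusion_check, computed on the raw N-D input, agrees with the old test on the flattened 2-D view, and the added output.view(x.shape) in fuse_mlp_base restores the caller's layout because the fused path operates on 2-D tensors.

import torch

for shape in [(1, 1, 4096), (1, 32, 4096), (4, 1, 4096)]:
    x = torch.randn(*shape)
    x_2d = x.view(-1, x.size(-1))

    old_reject = x_2d.shape[0] != 1            # old check: needed the explicit 2-D view
    new_reject = x.numel() // x.size(-1) != 1  # new check: token count from the raw input
    assert old_reject == new_reject

    # The fused kernel path consumes and produces 2-D [tokens, hidden] tensors;
    # the added view(x.shape) hands back a tensor with the original input layout.
    fused_out = torch.zeros_like(x_2d)         # stand-in for the fused MLP output
    assert fused_out.view(x.shape).shape == x.shape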