commit e0bf0054e1
parent 7ff4533b39
Author: Yishuo Wang
Date: 2024-12-04 16:37:39 +08:00 (committed by GitHub)
2 changed files with 5 additions and 7 deletions

@@ -116,18 +116,18 @@ def padding_qkv_hd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
     from ipex_llm.transformers.models.utils import mlp_fusion_check
-    x_2d = x.view(-1, x.size(-1))
     qtype = getattr(module.gate_proj, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, module.training):
+    if mlp_fusion_check(x, qtype, module.training):
         import xe_linear
-        x_2d = x_2d.contiguous()
-        return module.down_proj(
+        x_2d = x.contiguous().view(-1, x.size(-1))
+        output = module.down_proj(
             xe_linear.mlp_forward_xpu(
                 x_2d, module.gate_proj.weight.data, module.up_proj.weight.data,
                 x_2d.size(0), x_2d.size(1), module.gate_proj.out_len,
                 act, qtype
             )
         )
+        return output.view(x.shape)
     else:
         return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
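In effect, fuse_mlp_base now accepts inputs of any rank: it flattens the leading dimensions itself before calling the fused XPU kernel and restores the caller's shape on the way out. Below is a minimal sketch of that flatten-compute-restore contract; the SiLUMLP class is a hypothetical stand-in for a LLaMA-style MLP, and plain PyTorch replaces xe_linear.mlp_forward_xpu, which requires an Intel XPU build.

import torch

class SiLUMLP(torch.nn.Module):
    # Hypothetical stand-in for a LLaMA-style MLP with gate/up/down projections.
    def __init__(self, hidden: int, intermediate: int):
        super().__init__()
        self.gate_proj = torch.nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = torch.nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = torch.nn.Linear(intermediate, hidden, bias=False)
        self.act_fn = torch.nn.SiLU()

def fused_mlp_shape_sketch(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # The fused kernel consumes a contiguous 2-D tensor, so flatten all
    # leading dims into one row dimension (mirrors the patched code)...
    x_2d = x.contiguous().view(-1, x.size(-1))
    out_2d = module.down_proj(module.act_fn(module.gate_proj(x_2d)) * module.up_proj(x_2d))
    # ...then hand the caller back its original shape, matching
    # `return output.view(x.shape)` in the diff above.
    return out_2d.view(x.shape)

mlp = SiLUMLP(hidden=64, intermediate=128)
x = torch.randn(1, 1, 64)            # [batch, seq, hidden]: a single decoded token
assert fused_mlp_shape_sketch(mlp, x).shape == x.shape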

@@ -336,9 +336,7 @@ def use_sdp_non_causal(head_dim, device, dtype):
 def mlp_fusion_check(x, qtype, training):
-    invalidInputError(x.dim() == 2,
-                      "Here input x's dim should be 2.")
-    if x.shape[0] != 1:
+    if x.numel() // x.size(-1) != 1:
         return False
     if x.device.type != 'xpu':
         return False
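This replacement check is what makes the rank-agnostic call site above possible: x.numel() // x.size(-1) is the product of all leading dimensions, i.e. the number of token rows, whether x is 2-D or 3-D, so the old x.dim() == 2 assertion is no longer needed. A small standalone illustration (not library code):

import torch

for shape in [(1, 64), (1, 1, 64), (4, 64), (1, 7, 64)]:
    x = torch.empty(shape)
    rows = x.numel() // x.size(-1)   # product of all leading dims
    print(f"{shape}: {rows} row(s) -> fusion eligible: {rows == 1}")

# (1, 64):    1 row  -> eligible
# (1, 1, 64): 1 row  -> eligible (would have tripped the old x.dim() == 2 check)
# (4, 64):    4 rows -> not eligible
# (1, 7, 64): 7 rows -> not eligible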