small fix (#12493)
This commit is contained in:
parent
7ff4533b39
commit
e0bf0054e1
2 changed files with 5 additions and 7 deletions
|
|
@ -116,18 +116,18 @@ def padding_qkv_hd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|||
|
||||
def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
|
||||
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
||||
x_2d = x.view(-1, x.size(-1))
|
||||
qtype = getattr(module.gate_proj, "qtype", None)
|
||||
if mlp_fusion_check(x_2d, qtype, module.training):
|
||||
if mlp_fusion_check(x, qtype, module.training):
|
||||
import xe_linear
|
||||
x_2d = x_2d.contiguous()
|
||||
return module.down_proj(
|
||||
x_2d = x.contiguous().view(-1, x.size(-1))
|
||||
output = module.down_proj(
|
||||
xe_linear.mlp_forward_xpu(
|
||||
x_2d, module.gate_proj.weight.data, module.up_proj.weight.data,
|
||||
x_2d.size(0), x_2d.size(1), module.gate_proj.out_len,
|
||||
act, qtype
|
||||
)
|
||||
)
|
||||
return output.view(x.shape)
|
||||
else:
|
||||
return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
|
||||
|
||||
|
|
|
|||
|
|
@ -336,9 +336,7 @@ def use_sdp_non_causal(head_dim, device, dtype):
|
|||
|
||||
|
||||
def mlp_fusion_check(x, qtype, training):
|
||||
invalidInputError(x.dim() == 2,
|
||||
"Here input x's dim should be 2.")
|
||||
if x.shape[0] != 1:
|
||||
if x.numel() // x.size(-1) != 1:
|
||||
return False
|
||||
if x.device.type != 'xpu':
|
||||
return False
|
||||
|
|
|
|||
Loading…
Reference in a new issue