small fix (#12493)
This commit is contained in:
parent
7ff4533b39
commit
e0bf0054e1
2 changed files with 5 additions and 7 deletions
|
|
@ -116,18 +116,18 @@ def padding_qkv_hd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
|
|
||||||
def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
|
def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
|
||||||
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
from ipex_llm.transformers.models.utils import mlp_fusion_check
|
||||||
x_2d = x.view(-1, x.size(-1))
|
|
||||||
qtype = getattr(module.gate_proj, "qtype", None)
|
qtype = getattr(module.gate_proj, "qtype", None)
|
||||||
if mlp_fusion_check(x_2d, qtype, module.training):
|
if mlp_fusion_check(x, qtype, module.training):
|
||||||
import xe_linear
|
import xe_linear
|
||||||
x_2d = x_2d.contiguous()
|
x_2d = x.contiguous().view(-1, x.size(-1))
|
||||||
return module.down_proj(
|
output = module.down_proj(
|
||||||
xe_linear.mlp_forward_xpu(
|
xe_linear.mlp_forward_xpu(
|
||||||
x_2d, module.gate_proj.weight.data, module.up_proj.weight.data,
|
x_2d, module.gate_proj.weight.data, module.up_proj.weight.data,
|
||||||
x_2d.size(0), x_2d.size(1), module.gate_proj.out_len,
|
x_2d.size(0), x_2d.size(1), module.gate_proj.out_len,
|
||||||
act, qtype
|
act, qtype
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
return output.view(x.shape)
|
||||||
else:
|
else:
|
||||||
return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
|
return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -336,9 +336,7 @@ def use_sdp_non_causal(head_dim, device, dtype):
|
||||||
|
|
||||||
|
|
||||||
def mlp_fusion_check(x, qtype, training):
|
def mlp_fusion_check(x, qtype, training):
|
||||||
invalidInputError(x.dim() == 2,
|
if x.numel() // x.size(-1) != 1:
|
||||||
"Here input x's dim should be 2.")
|
|
||||||
if x.shape[0] != 1:
|
|
||||||
return False
|
return False
|
||||||
if x.device.type != 'xpu':
|
if x.device.type != 'xpu':
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue