Fix bug where torch.ops.torch_ipex.matmul_bias_out does not work on Linux MTL for short inputs (#11292)
commit 8edcdeb0e7
parent b61f6e3ab1

1 changed file with 6 additions and 3 deletions
@@ -813,14 +813,17 @@ class FP16Linear(nn.Linear):
             self.weight.data = self.weight.data.to(x.dtype)
 
         if not self.use_esimd_kernel(x):
-            if get_ipex_version() < "2.1.10+xpu":
+            if get_ipex_version() < "2.1.10+xpu" \
+                    or get_xpu_device_type(x) not in ["arc", "flex", "pvc"]:
                 if self.weight_type == 2:
-                    self.weight = self.weight.transpose(0, 1).contiguous()
+                    self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
+                                                     requires_grad=False)
                     self.weight_type = 1
                 result = F.linear(x, self.weight, self.bias)
             else:
                 if self.weight_type == 1:
-                    self.weight = self.weight.transpose(0, 1).contiguous()
+                    self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
+                                                     requires_grad=False)
                     self.weight_type = 2
                 result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
             if self.mp_group is not None:
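The fix narrows the fast path: the fused torch.ops.torch_ipex.matmul_bias_out op is now taken only on Arc, Flex, and PVC GPUs, while MTL (and any other unlisted device) falls back to F.linear. A minimal sketch of the resulting dispatch, assuming the repo's get_ipex_version and get_xpu_device_type helpers are importable, and eliding the weight_type layout bookkeeping shown in the diff:

import torch
import torch.nn.functional as F

def forward_sketch(self, x):
    # Sketch only: get_ipex_version and get_xpu_device_type are the repo's
    # own helpers (assumed importable); the weight layout handling that the
    # diff performs between the two branches is elided here.
    if get_ipex_version() < "2.1.10+xpu" \
            or get_xpu_device_type(x) not in ["arc", "flex", "pvc"]:
        # Old IPEX, or a device such as MTL where the fused op misbehaves
        # for short inputs: use the stock PyTorch kernel.
        return F.linear(x, self.weight, self.bias)
    # Recent IPEX on arc/flex/pvc: fused matmul + bias via IPEX.
    return torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)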
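The other change re-wraps the transposed weight in torch.nn.Parameter(..., requires_grad=False) rather than assigning the transposed tensor directly. That is required because nn.Module.__setattr__ refuses to overwrite a registered parameter with a plain tensor. A runnable toy demonstrating this on a stock nn.Linear:

import torch
import torch.nn as nn

m = nn.Linear(4, 8)  # weight shape: (8, 4)

# Assigning a plain tensor to a registered parameter raises:
#   TypeError: cannot assign 'torch.FloatTensor' as parameter 'weight'
#   (torch.nn.Parameter or None expected)
try:
    m.weight = m.weight.transpose(0, 1).contiguous()
except TypeError as e:
    print("rejected:", e)

# Re-wrapping in nn.Parameter, as the diff does, is accepted:
m.weight = nn.Parameter(m.weight.transpose(0, 1).contiguous(),
                        requires_grad=False)
print(m.weight.shape)  # torch.Size([4, 8])

Passing requires_grad=False matches the inference-only use of this layer; a fresh nn.Parameter would otherwise default to requires_grad=True.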