Fix bug where torch.ops.torch_ipex.matmul_bias_out does not work on Linux MTL for short input (#11292)

Author: Yuwen Hu, 2024-06-12 19:12:57 +08:00 (committed by GitHub)
Parent: b61f6e3ab1
Commit: 8edcdeb0e7

@@ -813,14 +813,17 @@ class FP16Linear(nn.Linear):
         self.weight.data = self.weight.data.to(x.dtype)
         if not self.use_esimd_kernel(x):
-            if get_ipex_version() < "2.1.10+xpu":
+            if get_ipex_version() < "2.1.10+xpu" \
+                    or get_xpu_device_type(x) not in ["arc", "flex", "pvc"]:
                 if self.weight_type == 2:
-                    self.weight = self.weight.transpose(0, 1).contiguous()
+                    self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
+                                                     requires_grad=False)
                     self.weight_type = 1
                 result = F.linear(x, self.weight, self.bias)
             else:
                 if self.weight_type == 1:
-                    self.weight = self.weight.transpose(0, 1).contiguous()
+                    self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
+                                                     requires_grad=False)
                     self.weight_type = 2
                 result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
 
             if self.mp_group is not None:
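
In short, the hunk restricts the fused torch.ops.torch_ipex.matmul_bias_out path to Arc, Flex, and PVC dGPUs; on any other XPU device (such as the Linux MTL iGPU from the title) or on IPEX older than 2.1.10+xpu, the forward pass now falls back to F.linear. Below is a minimal sketch of the resulting dispatch, assuming the repo's helpers get_ipex_version() and get_xpu_device_type() are importable (the import path shown is an assumption) and that intel_extension_for_pytorch has been imported, which registers the matmul_bias_out op; the standalone function is illustrative, not the module's actual method.

import torch
import torch.nn.functional as F
import intel_extension_for_pytorch  # noqa: F401, registers torch.ops.torch_ipex.*

# Assumed import path for the two helpers used in the diff.
from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type

def fp16_linear_forward(x, weight, bias):
    # weight is in the standard nn.Linear layout: (out_features, in_features).
    # The version check mirrors the diff's lexicographic string comparison.
    if get_ipex_version() < "2.1.10+xpu" \
            or get_xpu_device_type(x) not in ["arc", "flex", "pvc"]:
        # Older IPEX, or a device outside the dGPU allow-list (for example a
        # Linux MTL iGPU): plain F.linear handles inputs of any length.
        return F.linear(x, weight, bias)
    # Arc / Flex / PVC on IPEX >= 2.1.10+xpu: the fused op expects the weight
    # pre-transposed to (in_features, out_features).
    return torch.ops.torch_ipex.matmul_bias_out(
        x, weight.transpose(0, 1).contiguous(), bias
    )

In the module itself the transpose is done only once and cached: the second part of the diff stores the transposed weight back as torch.nn.Parameter(..., requires_grad=False) and flips self.weight_type, so subsequent forward calls skip the copy.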