fix fp16 linear (#12250)

Yishuo Wang 2024-10-23 14:35:19 +08:00, committed by GitHub
parent e8cf7f32f5
commit 88dc120a4c

@@ -886,7 +886,8 @@ class FP16Linear(nn.Linear):
                     self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
                                                      requires_grad=False)
                     self.weight_type = 2
-                result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
+                result = torch.ops.torch_ipex.matmul_bias_out(x.contiguous(),
+                                                              self.weight, self.bias)
         if self.mp_group is not None:
             if get_use_vllm():
                 result = self.mp_group.all_reduce(result)
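
The fix passes x.contiguous() into torch_ipex.matmul_bias_out instead of x, which guards against activations that arrive as non-contiguous views (e.g. after a transpose or slice). Below is a minimal sketch of the underlying contiguity issue; it assumes only a stock PyTorch install and does not require torch_ipex itself:

import torch

# A transposed tensor is a view over the same storage: its strides no
# longer match row-major order, so is_contiguous() reports False.
x = torch.randn(4, 8, dtype=torch.float16)
view = x.t()
print(view.is_contiguous())      # False

# .contiguous() materializes a row-major copy; this mirrors what the
# fix applies to the activation before the fused matmul-with-bias call.
fixed = view.contiguous()
print(fixed.is_contiguous())     # True
assert torch.equal(view, fixed)  # same values, different memory layout

Calling .contiguous() on an already-contiguous tensor is a no-op that returns the same tensor, so the extra call only costs a copy in the case the fix targets.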