fix fp16 linear (#12250)
commit 88dc120a4c
parent e8cf7f32f5
1 changed file with 2 additions and 1 deletion
@@ -886,7 +886,8 @@ class FP16Linear(nn.Linear):
                 self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
                                                  requires_grad=False)
                 self.weight_type = 2
-            result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
+            result = torch.ops.torch_ipex.matmul_bias_out(x.contiguous(),
+                                                          self.weight, self.bias)
             if self.mp_group is not None:
                 if get_use_vllm():
                     result = self.mp_group.all_reduce(result)
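
Note (my reading of the change, not part of the original commit message): the activation x reaching FP16Linear.forward can be non-contiguous (for example after a transpose or a slice upstream), and passing it through x.contiguous() gives the fused torch.ops.torch_ipex.matmul_bias_out kernel a contiguous buffer to work on. A minimal sketch in plain PyTorch of the contiguity behaviour being relied on; shapes and variable names here are illustrative assumptions and no IPEX is required to run it:

    import torch

    # An input tensor can become non-contiguous upstream,
    # e.g. when it comes out of a transpose or a slice.
    x = torch.randn(4, 8, 16, dtype=torch.float16)
    x_t = x.transpose(1, 2)          # a view with permuted strides
    print(x_t.is_contiguous())       # False

    # .contiguous() materializes a contiguous copy with the same values,
    # which is what the fix passes into matmul_bias_out.
    x_c = x_t.contiguous()
    print(x_c.is_contiguous())       # True
    assert torch.equal(x_t, x_c)     # same values, different memory layout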