From 88dc120a4c017bbda262ce92c0983747f0d152f0 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 23 Oct 2024 14:35:19 +0800
Subject: [PATCH] fix fp16 linear (#12250)

---
 python/llm/src/ipex_llm/transformers/low_bit_linear.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index d30126a6..cfbdf3ed 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -886,7 +886,8 @@ class FP16Linear(nn.Linear):
                 self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
                                                  requires_grad=False)
                 self.weight_type = 2
-            result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
+            result = torch.ops.torch_ipex.matmul_bias_out(x.contiguous(),
+                                                          self.weight, self.bias)
             if self.mp_group is not None:
                 if get_use_vllm():
                     result = self.mp_group.all_reduce(result)
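
A minimal sketch of the behaviour the added x.contiguous() relies on, assuming the
fused torch.ops.torch_ipex.matmul_bias_out kernel wants a dense activation: an input
that arrives as a transposed or sliced view is non-contiguous, and .contiguous()
copies it into row-major storage before the fused matmul. The sketch below uses plain
torch.nn.functional.linear on CPU as a stand-in for the IPEX op and float32 instead
of fp16; both substitutions are assumptions made so the example runs anywhere.

    # Sketch only: F.linear stands in for torch.ops.torch_ipex.matmul_bias_out,
    # and float32 on CPU stands in for fp16 on XPU.
    import torch
    import torch.nn.functional as F

    weight = torch.randn(128, 64)   # (out_features, in_features), as nn.Linear stores it
    bias = torch.randn(128)

    # An activation produced by an earlier transpose is a non-contiguous view.
    x = torch.randn(4, 64, 16).transpose(1, 2)   # shape (4, 16, 64)
    print(x.is_contiguous())                     # False

    # .contiguous() materializes the view into dense row-major memory,
    # mirroring the x.contiguous() added in FP16Linear.forward above.
    out = F.linear(x.contiguous(), weight, bias)
    print(out.shape)                             # torch.Size([4, 16, 128])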