diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py index b429c3ce..86a689ee 100644 --- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py +++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py @@ -813,14 +813,17 @@ class FP16Linear(nn.Linear): self.weight.data = self.weight.data.to(x.dtype) if not self.use_esimd_kernel(x): - if get_ipex_version() < "2.1.10+xpu": + if get_ipex_version() < "2.1.10+xpu" \ + or get_xpu_device_type(x) not in ["arc", "flex", "pvc"]: if self.weight_type == 2: - self.weight = self.weight.transpose(0, 1).contiguous() + self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(), + requires_grad=False) self.weight_type = 1 result = F.linear(x, self.weight, self.bias) else: if self.weight_type == 1: - self.weight = self.weight.transpose(0, 1).contiguous() + self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(), + requires_grad=False) self.weight_type = 2 result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias) if self.mp_group is not None: