diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py index 292c765a..13e27a05 100644 --- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py +++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py @@ -654,7 +654,8 @@ class LowBitLinear(nn.Linear): else: w = self.weight.data - if use_batch_forward(x_2d, self.weight.qtype, self.out_len) and self.conver_to_half: + if use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \ + (x_2d.dtype == torch.half or self.conver_to_half): import xe_batch result = xe_batch.batch_forward(x_2d, w, self.qtype) elif not is_training and self.conver_to_half \