diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index 63543d20..c51780ff 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -665,5 +665,16 @@ class BF16Linear(nn.Linear): if self.bias is not None and self.bias.dtype != x.dtype: self.bias.data = self.bias.data.to(x.dtype) + # If x.shape>3, F.linear will use bmm, accounting for performance degradation. + original_shape = x.shape + # Convert to 2D shape + if len(original_shape) > 2: + x = x.reshape(-1, original_shape[-1]) + result = F.linear(x, self.weight, self.bias) + + # Convert to original shape + if len(original_shape) > 2: + result = result.reshape(*original_shape[:-1], result.shape[-1]) + return result.to(x.dtype)