From 18cd1f1432f038e01e08ce7496f467976fe9a0d3 Mon Sep 17 00:00:00 2001
From: Ziteng Zhang <87107332+Jasonzzt@users.noreply.github.com>
Date: Wed, 17 Jan 2024 18:08:35 +0800
Subject: [PATCH] [LLM] Solve the problem of calling bmm operator in
 BF16Linear (#9924)

* Solve the problem of calling bmm operator in BF16Linear
---
 .../llm/src/bigdl/llm/transformers/low_bit_linear.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 63543d20..c51780ff 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -665,5 +665,16 @@ class BF16Linear(nn.Linear):
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
 
+        # If x has more than two dimensions, F.linear dispatches to the bmm operator, causing performance degradation.
+        original_shape = x.shape
+        # Flatten the input to 2D
+        if len(original_shape) > 2:
+            x = x.reshape(-1, original_shape[-1])
+
         result = F.linear(x, self.weight, self.bias)
+
+        # Restore the original shape
+        if len(original_shape) > 2:
+            result = result.reshape(*original_shape[:-1], result.shape[-1])
+
         return result.to(x.dtype)
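
Editor's note: below is a minimal standalone sketch (not part of the patch) illustrating
the change, assuming PyTorch is available; the tensor shapes and the names x, weight, and
bias are illustrative only, not taken from the repository. It checks that flattening the
input to 2D before F.linear and restoring the shape afterwards gives the same result as
the direct multi-dimensional call, which is the path the patch avoids for performance.

import torch
import torch.nn.functional as F

# Hypothetical 3D activation (batch, seq_len, hidden) in bfloat16;
# the sizes here are made up for illustration.
x = torch.randn(4, 128, 256, dtype=torch.bfloat16)
weight = torch.randn(512, 256, dtype=torch.bfloat16)
bias = torch.randn(512, dtype=torch.bfloat16)

# Direct call: per the patch, a >2D input makes F.linear take a
# batched-matmul (bmm) path rather than a single 2D GEMM.
direct = F.linear(x, weight, bias)

# Patched approach: flatten the leading dimensions so F.linear runs one
# plain 2D GEMM, then restore the original shape.
flat = x.reshape(-1, x.shape[-1])
result = F.linear(flat, weight, bias)
result = result.reshape(*x.shape[:-1], result.shape[-1])

# Both paths compute the same math; allow for kernel-level rounding.
assert torch.allclose(direct.float(), result.float(), rtol=1e-2, atol=1e-2)
print("flattened 2D path matches direct >2D F.linear call")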