[LLM] Solve the problem of calling bmm operator in BF16Linear (#9924)

* Solve the problem of calling bmm operator in BF16Linear
Ziteng Zhang 2024-01-17 18:08:35 +08:00 committed by GitHub
parent e403e4a8b7
commit 18cd1f1432


@@ -665,5 +665,16 @@ class BF16Linear(nn.Linear):
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
+        # If x has more than 2 dimensions, F.linear will use bmm, which degrades performance.
+        original_shape = x.shape
+        # Convert to a 2D shape
+        if len(original_shape) > 2:
+            x = x.reshape(-1, original_shape[-1])
         result = F.linear(x, self.weight, self.bias)
+        # Convert back to the original shape
+        if len(original_shape) > 2:
+            result = result.reshape(*original_shape[:-1], result.shape[-1])
         return result.to(x.dtype)
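
Below is a minimal, self-contained sketch of the same flatten-then-restore pattern, assuming only plain PyTorch; the helper name linear_2d_path and the toy shapes are illustrative and are not part of the original BF16Linear module:

import torch
import torch.nn.functional as F
from typing import Optional

def linear_2d_path(x: torch.Tensor, weight: torch.Tensor,
                   bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Hypothetical helper mirroring the commit's approach: flatten any leading
    # batch dimensions so F.linear runs a single 2D matmul instead of bmm,
    # then restore the original leading dimensions on the output.
    original_shape = x.shape
    if len(original_shape) > 2:
        x = x.reshape(-1, original_shape[-1])
    result = F.linear(x, weight, bias)
    if len(original_shape) > 2:
        result = result.reshape(*original_shape[:-1], result.shape[-1])
    return result

# Quick check with a 3D (batch, seq, hidden) bfloat16 input: the 2D path should
# agree with the direct N-D call up to bfloat16 rounding.
x = torch.randn(4, 16, 64, dtype=torch.bfloat16)
w = torch.randn(128, 64, dtype=torch.bfloat16)
b = torch.randn(128, dtype=torch.bfloat16)
out_2d = linear_2d_path(x, w, b)
out_nd = F.linear(x, w, b)
assert out_2d.shape == (4, 16, 128)
assert torch.allclose(out_2d.float(), out_nd.float(), rtol=1e-2, atol=1e-2)

The reshape around the call is cheap: on a contiguous input it returns a view rather than a copy, so the extra branch costs little compared with avoiding the batched-matmul path.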