[LLM]Solve the problem of calling bmm operator in BF16Linear (#9924)
* Solve the problem of calling bmm operator in BF16Linear
This commit is contained in:
parent
e403e4a8b7
commit
18cd1f1432
1 changed files with 11 additions and 0 deletions
|
|
@ -665,5 +665,16 @@ class BF16Linear(nn.Linear):
|
|||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
# If x.shape>3, F.linear will use bmm, accounting for performance degradation.
|
||||
original_shape = x.shape
|
||||
# Convert to 2D shape
|
||||
if len(original_shape) > 2:
|
||||
x = x.reshape(-1, original_shape[-1])
|
||||
|
||||
result = F.linear(x, self.weight, self.bias)
|
||||
|
||||
# Convert to original shape
|
||||
if len(original_shape) > 2:
|
||||
result = result.reshape(*original_shape[:-1], result.shape[-1])
|
||||
|
||||
return result.to(x.dtype)
|
||||
|
|
|
|||
Loading…
Reference in a new issue