[LLM]Solve the problem of calling bmm operator in BF16Linear (#9924)
* Solve the problem of calling bmm operator in BF16Linear
This commit is contained in:
parent
e403e4a8b7
commit
18cd1f1432
1 changed file with 11 additions and 0 deletions
|
|
@ -665,5 +665,16 @@ class BF16Linear(nn.Linear):
|
||||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||||
self.bias.data = self.bias.data.to(x.dtype)
|
self.bias.data = self.bias.data.to(x.dtype)
|
||||||
|
|
||||||
|
# If x has more than 2 dimensions, F.linear will use bmm, which degrades performance.
|
||||||
|
original_shape = x.shape
|
||||||
|
# Convert to 2D shape
|
||||||
|
if len(original_shape) > 2:
|
||||||
|
x = x.reshape(-1, original_shape[-1])
|
||||||
|
|
||||||
result = F.linear(x, self.weight, self.bias)
|
result = F.linear(x, self.weight, self.bias)
|
||||||
|
|
||||||
|
# Convert to original shape
|
||||||
|
if len(original_shape) > 2:
|
||||||
|
result = result.reshape(*original_shape[:-1], result.shape[-1])
|
||||||
|
|
||||||
return result.to(x.dtype)
|
return result.to(x.dtype)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue