[LLM] Add support for empty activation (#9664)
* Add support for empty activation, e.g., [0, 4096]. Empty activation is allowed by PyTorch. * Add comments.
This commit is contained in:
parent
284e7697b1
commit
5b0e7e308c
1 changed files with 10 additions and 3 deletions
|
|
@ -448,9 +448,18 @@ class LowBitLinear(nn.Linear):
|
|||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
# [batch, input_num, in_len]
|
||||
# input_num == token num for Transformer
|
||||
x_shape = x.shape
|
||||
x_2d = x.view(-1, x_shape[-1])
|
||||
# Output shape, e.g., [batch, input_num, out_len]
|
||||
new_shape = x_shape[:-1] + (self.out_len,)
|
||||
# Activation is empty tensor, e.g., [1, 0, 4096]
|
||||
if 0 in x_shape:
|
||||
# return empty tensor with output shape, x.dtype and x.device
|
||||
return torch.empty(new_shape, dtype=x.dtype, device=x.device)
|
||||
|
||||
x_2d = x.view(-1, x_shape[-1])
|
||||
# x0 for weight
|
||||
x0 = self.weight.data
|
||||
|
||||
if x0.device.type == "xpu":
|
||||
|
|
@ -489,7 +498,6 @@ class LowBitLinear(nn.Linear):
|
|||
else:
|
||||
result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
|
||||
input_seq_size)
|
||||
new_shape = x_shape[:-1] + (self.out_len,)
|
||||
result = result.view(new_shape)
|
||||
if self.mp_group is not None:
|
||||
from deepspeed import comm as dist
|
||||
|
|
@ -514,7 +522,6 @@ class LowBitLinear(nn.Linear):
|
|||
else:
|
||||
# Weight does not need a convert
|
||||
result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
|
||||
new_shape = x_shape[:-1] + (self.out_len,)
|
||||
result = result.view(new_shape)
|
||||
# allreduce to combine partial results and add bias if necessary
|
||||
if self.mp_group is not None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue