[LLM] Add support for empty activation (#9664)

* Add support for empty activation, e.g., [0, 4096]; empty activations are allowed by PyTorch (see the sketch after this list).
* Add comments.
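
For context, a minimal sketch of the case this commit handles: stock PyTorch layers already accept activations with a zero-sized dimension, so the low-bit path should too. The 4096/11008 widths below are illustrative values, not taken from the diff.

import torch
import torch.nn as nn

# PyTorch itself accepts an empty activation: matmul and nn.Linear
# simply propagate the zero-sized dimension to the output.
linear = nn.Linear(4096, 11008)
x = torch.empty(1, 0, 4096)   # [batch, input_num, in_len] with input_num == 0
y = linear(x)
print(y.shape)                # torch.Size([1, 0, 11008])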
Qiyuan Gong authored on 2023-12-13 11:07:45 +08:00; committed by GitHub
parent 284e7697b1
commit 5b0e7e308c


@@ -448,9 +448,18 @@ class LowBitLinear(nn.Linear):
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
         # [batch, input_num, in_len]
+        # input_num == token num for Transformer
         x_shape = x.shape
-        x_2d = x.view(-1, x_shape[-1])
+        # Output shape, e.g., [batch, input_num, out_len]
+        new_shape = x_shape[:-1] + (self.out_len,)
+        # Activation is empty tensor, e.g., [1, 0, 4096]
+        if 0 in x_shape:
+            # return empty tensor with output shape, x.dtype and x.device
+            return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+        x_2d = x.view(-1, x_shape[-1])
         # x0 for weight
         x0 = self.weight.data
         if x0.device.type == "xpu":
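
The added check works because torch.Size is a tuple subclass, so the membership test 0 in x_shape detects any zero-sized dimension; the early return then builds the output without ever touching the quantized kernels. A standalone sketch, with illustrative shape values and out_len standing in for self.out_len:

import torch

x = torch.empty(1, 0, 4096)              # empty activation
x_shape = x.shape
out_len = 11008                          # stand-in for self.out_len

new_shape = x_shape[:-1] + (out_len,)    # output shape: [1, 0, 11008]
if 0 in x_shape:                         # torch.Size supports membership tests
    # return an empty tensor that already has the right shape, dtype, device
    result = torch.empty(new_shape, dtype=x.dtype, device=x.device)
print(result.shape)                      # torch.Size([1, 0, 11008])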
@@ -489,7 +498,6 @@ class LowBitLinear(nn.Linear):
         else:
             result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
                                              input_seq_size)
-        new_shape = x_shape[:-1] + (self.out_len,)
         result = result.view(new_shape)
         if self.mp_group is not None:
             from deepspeed import comm as dist
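
Both kernel branches now reuse the hoisted new_shape to restore the leading dimensions after the 2D matmul, which is why the local recomputation is dropped here. A sketch of that reshape step, with illustrative values:

import torch

x_shape = (2, 5, 4096)                   # [batch, input_num, in_len]
out_len = 11008                          # illustrative output width
result_2d = torch.zeros(2 * 5, out_len)  # what the 2D kernel returns
new_shape = x_shape[:-1] + (out_len,)
result = result_2d.view(new_shape)       # back to [2, 5, 11008]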
@@ -514,7 +522,6 @@ class LowBitLinear(nn.Linear):
         else:
             # Weight does not need a convert
             result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
-        new_shape = x_shape[:-1] + (self.out_len,)
         result = result.view(new_shape)
         # allreduce to combine partial results and add bias if necessary
         if self.mp_group is not None:
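
When mp_group is set, each rank holds only a shard of the weight, so its matmul output is a partial result that must be summed across ranks before the bias is added. A hedged sketch of that step using torch.distributed (the diff itself goes through deepspeed.comm; combine_partial is a hypothetical helper and assumes an initialized process group):

import torch
import torch.distributed as dist

def combine_partial(result, bias=None, group=None):
    # Sum the per-rank partial outputs of the sharded matmul.
    if group is not None:
        dist.all_reduce(result, op=dist.ReduceOp.SUM, group=group)
    # Bias is added once, after the reduction.
    if bias is not None:
        result = result + bias
    return result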