[LLM] Add support for empty activation (#9664)
* Add support for empty activation, e.g., [0, 4096]. Empty activation is allowed by PyTorch. * Add comments.
This commit is contained in:
parent
284e7697b1
commit
5b0e7e308c
1 changed files with 10 additions and 3 deletions
|
|
@ -448,9 +448,18 @@ class LowBitLinear(nn.Linear):
|
||||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||||
self.bias.data = self.bias.data.to(x.dtype)
|
self.bias.data = self.bias.data.to(x.dtype)
|
||||||
|
|
||||||
|
# [batch, input_num, in_len]
|
||||||
|
# input_num == token num for Transformer
|
||||||
x_shape = x.shape
|
x_shape = x.shape
|
||||||
x_2d = x.view(-1, x_shape[-1])
|
# Output shape, e.g., [batch, input_num, out_len]
|
||||||
|
new_shape = x_shape[:-1] + (self.out_len,)
|
||||||
|
# Activation is empty tensor, e.g., [1, 0, 4096]
|
||||||
|
if 0 in x_shape:
|
||||||
|
# return empty tensor with output shape, x.dtype and x.device
|
||||||
|
return torch.empty(new_shape, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
x_2d = x.view(-1, x_shape[-1])
|
||||||
|
# x0 for weight
|
||||||
x0 = self.weight.data
|
x0 = self.weight.data
|
||||||
|
|
||||||
if x0.device.type == "xpu":
|
if x0.device.type == "xpu":
|
||||||
|
|
@ -489,7 +498,6 @@ class LowBitLinear(nn.Linear):
|
||||||
else:
|
else:
|
||||||
result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
|
result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
|
||||||
input_seq_size)
|
input_seq_size)
|
||||||
new_shape = x_shape[:-1] + (self.out_len,)
|
|
||||||
result = result.view(new_shape)
|
result = result.view(new_shape)
|
||||||
if self.mp_group is not None:
|
if self.mp_group is not None:
|
||||||
from deepspeed import comm as dist
|
from deepspeed import comm as dist
|
||||||
|
|
@ -514,7 +522,6 @@ class LowBitLinear(nn.Linear):
|
||||||
else:
|
else:
|
||||||
# Weight does not need a convert
|
# Weight does not need a convert
|
||||||
result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
|
result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
|
||||||
new_shape = x_shape[:-1] + (self.out_len,)
|
|
||||||
result = result.view(new_shape)
|
result = result.view(new_shape)
|
||||||
# allreduce to combine partial results and add bias if necessary
|
# allreduce to combine partial results and add bias if necessary
|
||||||
if self.mp_group is not None:
|
if self.mp_group is not None:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue