From 4c4f8d1663dd7149fde5b1bb701539b1da4f026a Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Mon, 9 Oct 2023 15:09:37 +0800
Subject: [PATCH] [LLM]Fix Arc falcon abnormal output issue (#9096)

* update
* update
* fix error & style
* fix style
* update train
* to input_seq_size
---
 .../src/bigdl/llm/transformers/low_bit_linear.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 8500f76a..0e3b7dba 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -288,10 +288,10 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
 class MatMulLowBit(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, A, weight):
+    def forward(ctx, A, weight, input_seq_size):
         ctx.is_empty = False
         import linear_q4_0
-        result = linear_q4_0.forward_new(A, weight.data, weight.qtype)
+        result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (A, weight)
         else:
@@ -304,14 +304,14 @@ class MatMulLowBit(torch.autograd.Function):
         if ctx.is_empty:
             bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-        req_gradA, _ = ctx.needs_input_grad
+        req_gradA, _, _ = ctx.needs_input_grad
         A, weight = ctx.tensors
         grad_A, grad_weight = None, None
         if req_gradA:
             dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype)
             grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))
 
-        return grad_A, grad_weight
+        return grad_A, grad_weight, None
 
 
 class LowBitLinear(nn.Linear):
@@ -353,10 +353,12 @@ class LowBitLinear(nn.Linear):
             # disable the conversion when training
             if self.conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
                 x_2d = x_2d.half()
+            input_seq_size = x_shape[1]
             if self.training and x_2d.requires_grad:
-                result = MatMulLowBit.apply(x_2d, self.weight)
+                result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
             else:
-                result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype)
+                result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
+                                                 input_seq_size)
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
             if self.bias is not None: