From 163d03361659082fdf19a02d84cab7759cc92d11 Mon Sep 17 00:00:00 2001
From: Yang Wang
Date: Sat, 28 Oct 2023 05:01:15 +0800
Subject: [PATCH] Support qlora in CPU (#9233)

* support qlora in CPU

* revert example

* fix style
---
 .../bigdl/llm/transformers/low_bit_linear.py  | 54 +++++++++++++++----
 1 file changed, 45 insertions(+), 9 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 7ea99f38..7d3266aa 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -323,6 +323,37 @@ class MatMulLowBit(torch.autograd.Function):
         return grad_A, grad_weight, None
 
 
+class MatMulLowBitCPU(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, A, weight):
+        if torch.is_autocast_cpu_enabled():
+            A = A.to(torch.get_autocast_cpu_dtype())
+        # dequantize the int4 weight to fp32 and run the matmul in fp32
+        x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape,
+                                         weight._shape[0] * weight._shape[1])
+        result = torch.matmul(A, x0_fp32.T)
+        # keep tensors for backward only when a gradient is actually needed
+        if any(ctx.needs_input_grad[:2]):
+            ctx.tensors = (A, weight)
+        else:
+            ctx.tensors = (None, None)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # gradients flow only to the activation A; the quantized weight
+        # stays frozen during QLoRA fine-tuning, so grad_weight is None
+        req_gradA, _ = ctx.needs_input_grad
+        A, weight = ctx.tensors
+        grad_A, grad_weight = None, None
+        if req_gradA:
+            x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape,
+                                             weight._shape[0] * weight._shape[1])
+            grad_A = torch.matmul(grad_output, x0_fp32.to(grad_output.dtype))
+        return grad_A, grad_weight
+
+
 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
                  conver_to_half=True, mp_group=None):
@@ -388,16 +419,21 @@
                               and self.qtype != FP4,
                               "NF3, NF4, FP4 and FP8 quantization are currently not"
                               " supported on CPU")
-            if IS_SERVER and (not IS_SPR) and \
-                    self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
-                x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
-                result = F.linear(x, x0_fp32, self.bias)
-            else:
-                result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
-                new_shape = x_shape[:-1] + (self.out_len,)
-                result = result.view(new_shape)
+            if self.training and x.requires_grad:
+                result = MatMulLowBitCPU.apply(x, self.weight)
                 if self.bias is not None:
-                    result += self.bias
+                    result = result + self.bias
+            else:
+                if IS_SERVER and (not IS_SPR) and \
+                        self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
+                    x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
+                    result = F.linear(x, x0_fp32, self.bias)
+                else:
+                    result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
+                    new_shape = x_shape[:-1] + (self.out_len,)
+                    result = result.view(new_shape)
+                    if self.bias is not None:
+                        result += self.bias
         return result
 
 
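
Note on the new autograd function: MatMulLowBitCPU dequantizes the int4 weight via ggml_int4_convert_fp32 in both forward and backward, and only ever produces a gradient for the activation A; the quantized base weight stays frozen, which is what QLoRA needs, since only the LoRA adapters receive updates. The snippet below is a minimal sanity check of that gradient rule against plain autograd. DequantMatMul is a hypothetical stand-in that uses an fp32 weight directly instead of calling ggml_int4_convert_fp32, so it runs without bigdl-llm installed.

    import torch

    class DequantMatMul(torch.autograd.Function):
        """Stand-in for MatMulLowBitCPU: same forward/backward structure,
        with a plain fp32 weight in place of the int4-quantized one."""

        @staticmethod
        def forward(ctx, A, weight_fp32):
            ctx.save_for_backward(weight_fp32)
            # forward: A @ W.T, like torch.matmul(A, x0_fp32.T) in the patch
            return torch.matmul(A, weight_fp32.T)

        @staticmethod
        def backward(ctx, grad_output):
            (weight_fp32,) = ctx.saved_tensors
            # gradient w.r.t. A only; the weight is treated as frozen
            grad_A = torch.matmul(grad_output, weight_fp32.to(grad_output.dtype))
            return grad_A, None

    A = torch.randn(4, 8, requires_grad=True)
    W = torch.randn(16, 8)  # (out_features, in_features)
    DequantMatMul.apply(A, W).sum().backward()

    # reference gradient from ordinary autograd through the same matmul
    A_ref = A.detach().clone().requires_grad_(True)
    torch.matmul(A_ref, W.T).sum().backward()
    assert torch.allclose(A.grad, A_ref.grad)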
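
Note on the dispatch change: LowBitLinear.forward now forks on self.training and x.requires_grad, so the fp32 autograd path is taken only when a gradient can actually flow; in eval mode, or for activations produced under torch.no_grad() (whose requires_grad is False), the fused int4 kernels are still used. A toy module, not bigdl-llm code, showing when each branch fires:

    import torch
    import torch.nn as nn

    class DispatchDemo(nn.Module):
        # mirrors the new LowBitLinear routing: autograd-friendly path
        # while training, fused int4 kernels otherwise
        def forward(self, x):
            if self.training and x.requires_grad:
                return "training path: MatMulLowBitCPU.apply"
            return "inference path: ggml kernels / F.linear"

    m = DispatchDemo()
    x = torch.randn(2, 4, requires_grad=True)
    print(m.train()(x))           # training path
    print(m.eval()(x))            # inference path: module in eval mode
    print(m.train()(x.detach()))  # inference path: no gradient requested

Gating on x.requires_grad as well as self.training means a model accidentally left in training mode still gets the fast inference kernels whenever no gradient is requested, rather than paying for the fp32 dequantize-and-matmul on every forward.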