Support qlora in CPU (#9233)

* support qlora in CPU

* revert example

* fix style
This commit is contained in:
Yang Wang 2023-10-28 05:01:15 +08:00 committed by GitHub
parent 8838707009
commit 163d033616

View file

@ -323,6 +323,37 @@ class MatMulLowBit(torch.autograd.Function):
return grad_A, grad_weight, None
class MatMulLowBitCPU(torch.autograd.Function):
@staticmethod
def forward(ctx, A, weight):
if torch.is_autocast_enabled():
A = A.to(torch.get_autocast_dtype())
ctx.is_empty = False
x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape,
weight._shape[0] * weight._shape[1])
result = torch.matmul(A, x0_fp32.T)
if any(ctx.needs_input_grad[:2]):
ctx.tensors = (A, weight)
else:
ctx.tensors = (None, None)
return result
@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, _, = ctx.needs_input_grad
A, weight = ctx.tensors
grad_A, grad_weight = None, None
if req_gradA:
x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape,
weight._shape[0] * weight._shape[1])
grad_A = torch.matmul(grad_output, x0_fp32.to(grad_output.dtype))
return grad_A, grad_weight, None
class LowBitLinear(nn.Linear):
def __init__(self, input_features, output_features, qtype, bias=True,
conver_to_half=True, mp_group=None):
@ -388,16 +419,21 @@ class LowBitLinear(nn.Linear):
and self.qtype != FP4,
"NF3, NF4, FP4 and FP8 quantization are currently not"
" supported on CPU")
if IS_SERVER and (not IS_SPR) and \
self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
result = F.linear(x, x0_fp32, self.bias)
else:
result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
new_shape = x_shape[:-1] + (self.out_len,)
result = result.view(new_shape)
if self.training and x.requires_grad:
result = MatMulLowBitCPU.apply(x, self.weight)
if self.bias is not None:
result += self.bias
result = result + self.bias
else:
if IS_SERVER and (not IS_SPR) and \
self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
result = F.linear(x, x0_fp32, self.bias)
else:
result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
new_shape = x_shape[:-1] + (self.out_len,)
result = result.view(new_shape)
if self.bias is not None:
result += self.bias
return result