Support qlora in CPU (#9233)
* support qlora in CPU
* revert example
* fix style
parent 8838707009
commit 163d033616
1 changed file with 45 additions and 9 deletions
@@ -323,6 +323,37 @@ class MatMulLowBit(torch.autograd.Function):
         return grad_A, grad_weight, None
 
 
+class MatMulLowBitCPU(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, A, weight):
+        if torch.is_autocast_cpu_enabled():
+            A = A.to(torch.get_autocast_cpu_dtype())
+        ctx.is_empty = False
+        x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape,
+                                         weight._shape[0] * weight._shape[1])
+        result = torch.matmul(A, x0_fp32.T)
+        if any(ctx.needs_input_grad[:2]):
+            ctx.tensors = (A, weight)
+        else:
+            ctx.tensors = (None, None)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.is_empty:
+            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+        req_gradA, _ = ctx.needs_input_grad
+        A, weight = ctx.tensors
+        grad_A, grad_weight = None, None
+        if req_gradA:
+            x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape,
+                                             weight._shape[0] * weight._shape[1])
+            grad_A = torch.matmul(grad_output, x0_fp32.to(grad_output.dtype))
+        return grad_A, grad_weight, None
+
+
 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
                  conver_to_half=True, mp_group=None):
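The pattern in this hunk is the standard one for QLoRA over a frozen quantized base weight: forward dequantizes the int4 weight on the fly and runs a plain fp32 matmul, and backward reuses the dequantized weight to produce only the activation gradient (`grad_weight` stays `None`, since only the LoRA adapters train). Below is a minimal, self-contained sketch of that pattern; `fake_dequantize` and the class name `MatMulDequantCPU` are hypothetical stand-ins for `ggml_int4_convert_fp32` and the commit's `MatMulLowBitCPU`, not the library's code:

```python
import torch


def fake_dequantize(packed, shape):
    # Stand-in for ggml_int4_convert_fp32: here an fp32 tensor plays the
    # quantized weight, so "dequantizing" is just a reshape.
    return packed.reshape(shape)


class MatMulDequantCPU(torch.autograd.Function):
    # Mirrors the structure of MatMulLowBitCPU: dequantize in forward,
    # reuse the dequantized weight in backward for grad wrt activations only.

    @staticmethod
    def forward(ctx, A, weight, shape):
        w_fp32 = fake_dequantize(weight, shape)
        ctx.save_for_backward(weight)
        ctx.shape = shape
        return torch.matmul(A, w_fp32.T)

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        w_fp32 = fake_dequantize(weight, ctx.shape)
        # No weight gradient: the quantized base model is frozen in QLoRA.
        grad_A = torch.matmul(grad_output, w_fp32.to(grad_output.dtype))
        return grad_A, None, None


A = torch.randn(4, 8, requires_grad=True)
W = torch.randn(16, 8)
MatMulDequantCPU.apply(A, W.flatten(), (16, 8)).sum().backward()

# The custom backward matches ordinary autograd on the same fp32 matmul.
A_ref = A.detach().clone().requires_grad_(True)
(A_ref @ W.T).sum().backward()
assert torch.allclose(A.grad, A_ref.grad)
```

The closing assertion checks the hand-written `backward` against plain autograd, which is a quick way to validate custom `torch.autograd.Function` implementations.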
@@ -388,16 +419,21 @@ class LowBitLinear(nn.Linear):
                                and self.qtype != FP4,
                                "NF3, NF4, FP4 and FP8 quantization are currently not"
                                " supported on CPU")
-            if IS_SERVER and (not IS_SPR) and \
-                    self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
-                x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
-                result = F.linear(x, x0_fp32, self.bias)
-            else:
-                result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
-                new_shape = x_shape[:-1] + (self.out_len,)
-                result = result.view(new_shape)
+            if self.training and x.requires_grad:
+                result = MatMulLowBitCPU.apply(x, self.weight)
                 if self.bias is not None:
-                    result += self.bias
+                    result = result + self.bias
+            else:
+                if IS_SERVER and (not IS_SPR) and \
+                        self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
+                    x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
+                    result = F.linear(x, x0_fp32, self.bias)
+                else:
+                    result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
+                    new_shape = x_shape[:-1] + (self.out_len,)
+                    result = result.view(new_shape)
+                    if self.bias is not None:
+                        result += self.bias
             return result
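After this hunk, the CPU forward path dispatches three ways: training goes through the differentiable `MatMulLowBitCPU` (with an out-of-place bias add, `result = result + self.bias`), while inference keeps the earlier split between dequantize-plus-`F.linear` for large batches on non-SPR servers and the int4 ggml kernel otherwise. A hedged sketch of that dispatch follows; the standalone `low_bit_forward` signature, both stub functions, and the threshold value are illustrative stand-ins for the library's `ggml_int4_convert_fp32`, `ggml_matmul_src1_x_src0_t`, and `TORCH_LINEAR_THRESHOLD`, and the qtype check is elided for brevity:

```python
import torch
import torch.nn.functional as F

TORCH_LINEAR_THRESHOLD = 96  # illustrative value, not the library constant


def dequantize(weight_q, shape):
    # Stand-in for ggml_int4_convert_fp32: an fp32 tensor plays the int4 weight.
    return weight_q.reshape(shape)


def int4_kernel_matmul(x_2d, weight_q, shape):
    # Stand-in for ggml_matmul_src1_x_src0_t: same math, no int4 kernel.
    return x_2d @ dequantize(weight_q, shape).T


def low_bit_forward(x_2d, weight_q, shape, bias, training, is_server, is_spr):
    if training and x_2d.requires_grad:
        # Differentiable path (the commit routes this through MatMulLowBitCPU);
        # a plain dequantize + matmul is used here to stay self-contained.
        result = x_2d @ dequantize(weight_q, shape).T
        return result if bias is None else result + bias  # out-of-place add
    if is_server and not is_spr and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
        # Large batches on non-SPR servers: dequantize once, use torch's GEMM.
        return F.linear(x_2d, dequantize(weight_q, shape), bias)
    # Default inference path: the int4 kernel.
    result = int4_kernel_matmul(x_2d, weight_q, shape)
    return result if bias is None else result + bias


x = torch.randn(4, 8)
w = torch.randn(16, 8).flatten()
b = torch.randn(16)
y = low_bit_forward(x, w, (16, 8), b, training=False, is_server=True, is_spr=False)
assert y.shape == (4, 16)
```

The small batch here (4 rows, below the threshold) lands on the int4-kernel branch, which is the common single-request inference case.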