LLM: fix unlora module in qlora finetune (#9621)

* fix unlora module

* split train and inference
Ruonan Wang 2023-12-07 16:32:02 +08:00 committed by GitHub
parent 3811cf43c9
commit d9b0c01de3
2 changed files with 21 additions and 2 deletions

@@ -50,6 +50,7 @@ from torch import Tensor, device, dtype, nn
 from operator import mul
 from functools import reduce
 from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
+from bigdl.llm.transformers.utils import get_autocast_dtype
 
 T = TypeVar("T", bound="torch.nn.Module")
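
The added import brings in get_autocast_dtype, which the new training branch below uses to pick a cast target when no compute_dtype is pinned. Its implementation is not part of this diff; a rough standalone sketch of such a helper, assuming it only surfaces torch's autocast state for the input's device (the real BigDL helper may also cover XPU autocast), could look like:

import torch

def get_autocast_dtype_sketch(x: torch.Tensor):
    # Hypothetical stand-in for bigdl.llm.transformers.utils.get_autocast_dtype:
    # report the active autocast dtype for x's device, or None if autocast is off.
    if x.device.type == "cuda" and torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    if x.device.type == "cpu" and torch.is_autocast_cpu_enabled():
        return torch.get_autocast_cpu_dtype()
    # XPU autocast (via intel_extension_for_pytorch) is omitted in this sketch.
    return None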
@@ -433,8 +434,17 @@ class LowBitLinear(nn.Linear):
         self.qtype = qtype
         self.conver_to_half = conver_to_half
         self.mp_group = mp_group
+        self.compute_dtype = None  # only for training
 
     def forward(self, x: torch.Tensor):
+        if self.training:
+            # below logic is only for training
+            autocast_dtype = get_autocast_dtype(x)
+            if self.compute_dtype is not None and x.device.type == "xpu":
+                x = x.to(self.compute_dtype)  # solve GC issue for unlora module
+            elif autocast_dtype is not None:
+                x = x.to(autocast_dtype)
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
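
Effect of the new block: while training, inputs to a LowBitLinear on XPU are cast to an explicitly pinned compute_dtype (this is the fix for the non-LoRA, i.e. "unlora", modules mentioned in the comment's "GC issue", presumably gradient checkpointing), and otherwise fall back to the ambient autocast dtype. A tiny self-contained illustration of that precedence rule (the names here are illustrative, not BigDL's):

import torch

def pick_train_input_dtype(x, compute_dtype, autocast_dtype):
    # Mirrors the precedence added above: a pinned compute_dtype wins on XPU,
    # otherwise the ambient autocast dtype, otherwise leave the input alone.
    if compute_dtype is not None and x.device.type == "xpu":
        return compute_dtype
    if autocast_dtype is not None:
        return autocast_dtype
    return x.dtype

x = torch.randn(4, 8)  # CPU tensor for illustration, so the xpu branch stays idle
print(pick_train_input_dtype(x, torch.float16, torch.bfloat16))  # torch.bfloat16 (not on xpu)
print(pick_train_input_dtype(x, None, None))                     # torch.float32 (unchanged)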
@@ -457,9 +467,16 @@ class LowBitLinear(nn.Linear):
                 x_2d = x_2d.contiguous()
 
             input_seq_size = x_shape[1]
-            if self.training and x_2d.requires_grad:
-                result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
+            if self.training:
+                # training path
+                if x_2d.requires_grad:
+                    result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
+                else:
+                    result = linear_q4_0.forward_new(x_2d, self.weight.data,
+                                                     self.weight.qtype,
+                                                     input_seq_size)
             else:
+                # inference path
                 # current workaround to reduce first token latency of fp32 input
                 # sometimes fp16 cause nan and training instability
                 # disable the conversion when training

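The second half of the change ("split train and inference"): in training mode the forward now dispatches on whether the flattened input actually requires grad, going through the MatMulLowBit autograd Function when it does and straight to the quantized kernel when it does not, while inference keeps the old fp16 workaround path. A generic sketch of that dispatch pattern, with a dummy autograd.Function and a plain matmul standing in for BigDL's kernels:

import torch

class _ToyMatMul(torch.autograd.Function):
    # Stand-in for MatMulLowBit: saves what backward needs, then a plain matmul.
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        return grad_out @ weight, grad_out.t() @ x

def toy_low_bit_forward(x_2d, weight, training):
    if training and x_2d.requires_grad:
        # training path with gradients: go through the autograd Function
        return _ToyMatMul.apply(x_2d, weight)
    # no-grad training input or inference: call the kernel (here: plain linear) directly
    return torch.nn.functional.linear(x_2d, weight)

x = torch.randn(4, 16, requires_grad=True)
w = torch.randn(32, 16)
toy_low_bit_forward(x, w, training=True).sum().backward()   # gradients flow to x
print(x.grad.shape)                                          # torch.Size([4, 16])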

@@ -391,6 +391,8 @@ TrainingArguments._setup_devices = _setup_devices
 def cast_lora_weight(model, dtype=torch.bfloat16):
     for name, module in model.named_modules():
+        if isinstance(module, LowBitLinear):
+            module.compute_dtype = dtype
         if isinstance(module, LoraLayer):
             module = module.to(dtype)
         if 'norm' in name:
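
With the two added lines, cast_lora_weight now also pins compute_dtype on every frozen LowBitLinear, so the non-LoRA base layers compute in the same dtype as the LoRA adapters during QLoRA fine-tuning. A hedged usage sketch (the import paths, model id and LoRA settings below are assumptions modeled on BigDL's QLoRA examples, not part of this diff):

import torch
from peft import LoraConfig
# Assumed import paths; the diff shows only the function body, not its module.
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.transformers.qlora import get_peft_model, cast_lora_weight

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",    # illustrative model id
    load_in_low_bit="nf4",         # quantized base weights become LowBitLinear modules
)
model = model.to("xpu")

lora_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05,
                         target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)

# Cast LoRA adapters to bf16 and, after this change, also set compute_dtype on the
# frozen LowBitLinear layers so their training inputs are cast consistently.
cast_lora_weight(model, dtype=torch.bfloat16)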