LLM: fix unlora module in qlora finetune (#9621)
* fix unlora module
* split train and inference

This commit is contained in: parent 3811cf43c9 · commit d9b0c01de3

2 changed files with 21 additions and 2 deletions
@@ -50,6 +50,7 @@ from torch import Tensor, device, dtype, nn
 from operator import mul
 from functools import reduce
 from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
+from bigdl.llm.transformers.utils import get_autocast_dtype

 T = TypeVar("T", bound="torch.nn.Module")

@@ -433,8 +434,17 @@ class LowBitLinear(nn.Linear):
         self.qtype = qtype
         self.conver_to_half = conver_to_half
         self.mp_group = mp_group
+        self.compute_dtype = None  # only for training

     def forward(self, x: torch.Tensor):
+        if self.training:
+            # below logic is only for training
+            autocast_dtype = get_autocast_dtype(x)
+            if self.compute_dtype is not None and x.device.type == "xpu":
+                x = x.to(self.compute_dtype)  # solve GC issue for unlora module
+            elif autocast_dtype is not None:
+                x = x.to(autocast_dtype)
+
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)

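For context, here is a minimal standalone sketch of the training-side dtype handling introduced in the hunk above. ToyLowBitLinear and get_autocast_dtype_stub are illustrative stand-ins, not the BigDL implementation: the real LowBitLinear holds quantized weights, and the real get_autocast_dtype comes from bigdl.llm.transformers.utils.

import torch
from torch import nn


def get_autocast_dtype_stub(x: torch.Tensor):
    # Stand-in for bigdl.llm.transformers.utils.get_autocast_dtype (assumed to
    # return the active autocast dtype for x's device, or None if autocast is off).
    if x.device.type == "cuda" and torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    return None


class ToyLowBitLinear(nn.Linear):
    # Plain nn.Linear used only to illustrate the dtype flow; the real
    # LowBitLinear stores quantized weights.
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.compute_dtype = None  # only consulted during training

    def forward(self, x: torch.Tensor):
        if self.training:
            # Training-only cast: frozen (un-LoRA) modules on XPU follow the
            # explicit compute dtype, otherwise fall back to the autocast dtype.
            autocast_dtype = get_autocast_dtype_stub(x)
            if self.compute_dtype is not None and x.device.type == "xpu":
                x = x.to(self.compute_dtype)
            elif autocast_dtype is not None:
                x = x.to(autocast_dtype)
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)
        return nn.functional.linear(x, self.weight.to(x.dtype), self.bias)


layer = ToyLowBitLinear(8, 4).train()
layer.compute_dtype = torch.bfloat16
print(layer(torch.randn(2, 8)).dtype)  # float32 on CPU: neither cast applies here

In the commit, compute_dtype is set externally (see the cast_lora_weight change further down); the in-code comment attributes the original problem to a "GC issue" hit by un-LoRA modules, which this explicit cast works around.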
@@ -457,9 +467,16 @@
             x_2d = x_2d.contiguous()

         input_seq_size = x_shape[1]
-        if self.training and x_2d.requires_grad:
-            result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
+        if self.training:
+            # training path
+            if x_2d.requires_grad:
+                result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
+            else:
+                result = linear_q4_0.forward_new(x_2d, self.weight.data,
+                                                 self.weight.qtype,
+                                                 input_seq_size)
         else:
+            # inference path
             # current workaround to reduce first token latency of fp32 input
             # sometimes fp16 cause nan and training instability
             # disable the conversion when training
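The hunk above splits the matmul dispatch into a training path and an inference path. Below is a sketch of that dispatch pattern with placeholder kernels standing in for MatMulLowBit and linear_q4_0.forward_new (the real ones operate on quantized weights); only the control flow is meant to match.

import torch


class ToyMatMulLowBit(torch.autograd.Function):
    # Placeholder for MatMulLowBit: the autograd-aware matmul, used only when
    # the input actually requires gradients.
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        return grad_output @ weight, grad_output.t() @ x


def fused_no_grad_matmul(x, weight):
    # Placeholder for linear_q4_0.forward_new: a fused kernel that never
    # records an autograd graph.
    with torch.no_grad():
        return x @ weight.t()


def low_bit_matmul(x, weight, training):
    if training:
        # training path: inputs of frozen (un-LoRA) modules skip the autograd wrapper
        if x.requires_grad:
            return ToyMatMulLowBit.apply(x, weight)
        return fused_no_grad_matmul(x, weight)
    # inference path
    return fused_no_grad_matmul(x, weight)


w = torch.randn(4, 8)
x_frozen = torch.randn(2, 8)                    # feeds a frozen module
x_lora = torch.randn(2, 8, requires_grad=True)  # feeds a LoRA branch
print(low_bit_matmul(x_frozen, w, training=True).requires_grad)  # False
print(low_bit_matmul(x_lora, w, training=True).requires_grad)    # True

The design choice: frozen modules under training no longer pay for the autograd wrapper, since only inputs that actually require gradients (the LoRA branches) go through MatMulLowBit.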
@@ -391,6 +391,8 @@ TrainingArguments._setup_devices = _setup_devices

 def cast_lora_weight(model, dtype=torch.bfloat16):
     for name, module in model.named_modules():
+        if isinstance(module, LowBitLinear):
+            module.compute_dtype = dtype
         if isinstance(module, LoraLayer):
             module = module.to(dtype)
         if 'norm' in name:
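In the second changed file, cast_lora_weight is extended so that every LowBitLinear is tagged with the training compute dtype. Below is a toy sketch of that module walk using placeholder classes (not the real bigdl LowBitLinear or peft LoraLayer); the 'norm' branch is truncated in the hunk above and is therefore left out here.

import torch
from torch import nn


class ToyLowBitLinear(nn.Module):
    # Placeholder for LowBitLinear: only the compute_dtype attribute matters here.
    def __init__(self):
        super().__init__()
        self.compute_dtype = None


class ToyLoraLayer(nn.Linear):
    # Placeholder for peft's LoraLayer.
    pass


def cast_lora_weight_sketch(model, dtype=torch.bfloat16):
    for _name, module in model.named_modules():
        if isinstance(module, ToyLowBitLinear):
            # New in this commit: frozen low-bit linears learn the training
            # compute dtype so their forward() can cast inputs to match.
            module.compute_dtype = dtype
        if isinstance(module, ToyLoraLayer):
            module.to(dtype)  # LoRA adapter weights move to the training dtype


model = nn.Sequential(ToyLowBitLinear(), ToyLoraLayer(4, 4))
cast_lora_weight_sketch(model)
print(model[0].compute_dtype)  # torch.bfloat16
print(model[1].weight.dtype)   # torch.bfloat16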