LLM: fix unlora module in qlora finetune (#9621)
* fix unlora module
* split train and inference
parent 3811cf43c9
commit d9b0c01de3

2 changed files with 21 additions and 2 deletions
@@ -50,6 +50,7 @@ from torch import Tensor, device, dtype, nn
 from operator import mul
 from functools import reduce
 from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
+from bigdl.llm.transformers.utils import get_autocast_dtype
 
 T = TypeVar("T", bound="torch.nn.Module")
 
@@ -433,8 +434,17 @@ class LowBitLinear(nn.Linear):
         self.qtype = qtype
         self.conver_to_half = conver_to_half
         self.mp_group = mp_group
+        self.compute_dtype = None  # only for training
 
     def forward(self, x: torch.Tensor):
+        if self.training:
+            # below logic is only for training
+            autocast_dtype = get_autocast_dtype(x)
+            if self.compute_dtype is not None and x.device.type == "xpu":
+                x = x.to(self.compute_dtype)  # solve GC issue for unlora module
+            elif autocast_dtype is not None:
+                x = x.to(autocast_dtype)
+
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
 
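The new dtype rule in forward is worth reading in isolation: during training, an explicitly configured compute_dtype wins on XPU devices (it is set per-module by cast_lora_weight in the second changed file below), otherwise the ambient autocast dtype applies; at inference the input is left untouched at this point. A minimal sketch of that decision, where pick_input_dtype is a hypothetical helper and get_autocast_dtype is assumed to return the active autocast dtype or None, per its use above:

import torch

def pick_input_dtype(x, training, compute_dtype, autocast_dtype):
    # Training only: a configured compute_dtype overrides autocast on XPU
    # (this is what cast_lora_weight sets for un-lora'd LowBitLinear
    # modules); otherwise fall back to the ambient autocast dtype.
    if not training:
        return x.dtype  # inference: no cast happens at this point
    if compute_dtype is not None and x.device.type == "xpu":
        return compute_dtype
    if autocast_dtype is not None:
        return autocast_dtype
    return x.dtype

x = torch.randn(2, 4)  # cpu tensor, so the xpu-only override is skipped
assert pick_input_dtype(x, True, torch.bfloat16, None) == torch.float32
assert pick_input_dtype(x, True, None, torch.bfloat16) == torch.bfloat16
assert pick_input_dtype(x, False, torch.bfloat16, torch.float16) == torch.float32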
@@ -457,9 +467,16 @@ class LowBitLinear(nn.Linear):
                 x_2d = x_2d.contiguous()
 
             input_seq_size = x_shape[1]
-            if self.training and x_2d.requires_grad:
-                result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
+            if self.training:
+                # training path
+                if x_2d.requires_grad:
+                    result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
+                else:
+                    result = linear_q4_0.forward_new(x_2d, self.weight.data,
+                                                     self.weight.qtype,
+                                                     input_seq_size)
             else:
+                # inference path
                 # current workaround to reduce first token latency of fp32 input
                 # sometimes fp16 cause nan and training instability
                 # disable the conversion when training
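The point of the split: with self.training now checked separately from requires_grad, a frozen (un-lora'd) low-bit layer seen during training no longer falls through to the inference path and its fp16 first-token workaround. A toy stand-in that keeps only the control flow (TinyLowBitLinear is hypothetical; MatMulLowBit.apply and linear_q4_0.forward_new are replaced by plain nn.Linear math so it runs anywhere):

import torch
from torch import nn

class TinyLowBitLinear(nn.Linear):
    def forward(self, x):
        if self.training:
            if x.requires_grad:
                # real code: MatMulLowBit.apply(...), the autograd-aware matmul
                return super().forward(x)
            # real code: linear_q4_0.forward_new(...), a no-grad fused kernel
            with torch.no_grad():
                return super().forward(x)
        # inference path: the fp32->fp16 first-token workaround lives here
        return super().forward(x)

layer = TinyLowBitLinear(4, 4)
layer.train()
y = layer(torch.randn(2, 4, requires_grad=True))  # training, grad path
y = layer(torch.randn(2, 4))                      # training, frozen path
layer.eval()
with torch.no_grad():
    y = layer(torch.randn(2, 4))                  # inference path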
@@ -391,6 +391,8 @@ TrainingArguments._setup_devices = _setup_devices
 
 def cast_lora_weight(model, dtype=torch.bfloat16):
     for name, module in model.named_modules():
+        if isinstance(module, LowBitLinear):
+            module.compute_dtype = dtype
         if isinstance(module, LoraLayer):
             module = module.to(dtype)
         if 'norm' in name:
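End to end, cast_lora_weight now also stamps compute_dtype onto every LowBitLinear, including frozen ones that carry no LoRA adapter. A minimal runnable sketch with stand-in classes (FakeLowBitLinear and ToyBlock are hypothetical; the real LowBitLinear and LoraLayer come from bigdl-llm and peft, and the truncated 'norm' branch is assumed to keep norm layers in fp32, the usual QLoRA practice):

import torch
from torch import nn

class FakeLowBitLinear(nn.Linear):
    # stand-in for bigdl LowBitLinear; mirrors the attribute added here
    compute_dtype = None

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = FakeLowBitLinear(4, 4)  # an un-lora'd low-bit layer
        self.norm = nn.LayerNorm(4)

def cast_lora_weight(model, dtype=torch.bfloat16):
    for name, module in model.named_modules():
        if isinstance(module, FakeLowBitLinear):
            module.compute_dtype = dtype    # new in this commit
        if 'norm' in name:
            module.to(torch.float32)        # assumed: norms stay fp32

model = ToyBlock()
cast_lora_weight(model)
assert model.proj.compute_dtype is torch.bfloat16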