LLM: GPU QLoRA update to bf16 to accelerate gradient checkpointing (#9499)
* update to bf16 to accelerate gradient checkpointing
* add utils and fix unit tests
parent 3e39828420
commit 076d106ef5

4 changed files with 27 additions and 11 deletions

@@ -51,7 +51,8 @@ import intel_extension_for_pytorch as ipex
 from bigdl.llm.transformers import AutoModelForCausalLM
 
 # import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
+    cast_lora_weight
 
 def get_int_from_env(env_keys, default):
     """Returns the first positive env value found in the `env_keys` list or the default."""

@@ -76,6 +77,7 @@ def train(
     data_path: str = "yahma/alpaca-cleaned",
     output_dir: str = "./bigdl-qlora-alpaca",
     # training hyperparams
+    bf16: bool = True,  # default to bf16
     batch_size: int = 128,
     micro_batch_size: int = 2,  # default to be 2, limited by GPU memory
     num_epochs: int = 3,

@@ -301,6 +303,9 @@ def train(
     #     model.is_parallelizable = True
     #     model.model_parallel = True
 
+    if bf16:
+        cast_lora_weight(model, torch.bfloat16)
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_data,

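Taken together, the three hunks above give the alpaca QLoRA fine-tuning script a `bf16` flag, import `cast_lora_weight`, and cast the LoRA adapters to bfloat16 right before the Trainer is built. A condensed, hypothetical sketch of that wiring (only the `cast_lora_weight` call and the `bf16` default come from this diff; the TrainingArguments values are marked where they are assumptions):

    import torch
    import transformers
    from bigdl.llm.transformers.qlora import cast_lora_weight

    def build_trainer(model, train_data, bf16: bool = True):
        # mirrors the new block in train(): cast the LoRA adapters before training starts
        if bf16:
            cast_lora_weight(model, torch.bfloat16)
        return transformers.Trainer(
            model=model,
            train_dataset=train_data,
            args=transformers.TrainingArguments(
                output_dir="./bigdl-qlora-alpaca",  # default taken from the script
                per_device_train_batch_size=2,      # assumption: maps to micro_batch_size above
                bf16=bf16,                          # assumption: the flag is forwarded to TrainingArguments
                gradient_checkpointing=True,        # assumption: implied by the commit title, not shown here
            ),
        )
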
@@ -72,16 +72,11 @@ if __name__ == "__main__":
         0
     ].self_attn.q_proj.weight
 
-    assert torch.allclose(first_weight_old, first_weight)
-
     # merge weights - new merging method from peft
     lora_model = lora_model.merge_and_unload()
 
     lora_model.train(False)
 
-    # did we do anything?
-    assert not torch.allclose(first_weight_old, first_weight)
-
     lora_model_sd = lora_model.state_dict()
     deloreanized_sd = {
         k.replace("base_model.model.", ""): v

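The removed lines are the pre- and post-merge `torch.allclose` assertions on the q_proj weight, presumably dropped because the weights are no longer plain fp32 tensors once the adapters are cast and the base model is loaded in low-bit form. What remains of the export flow can be summarized by a small hypothetical helper (the `"lora" not in k` filter is an assumption carried over from the upstream alpaca-lora exporter, not shown in this diff):

    def export_merged_state_dict(lora_model):
        # fold the LoRA deltas into the base weights, then strip the PEFT wrapper prefix
        lora_model = lora_model.merge_and_unload()
        lora_model.train(False)
        return {
            k.replace("base_model.model.", ""): v
            for k, v in lora_model.state_dict().items()
            if "lora" not in k  # assumption: drop adapter-only tensors before saving
        }
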
@@ -19,10 +19,10 @@ import os
 
 import transformers
 from transformers import LlamaTokenizer
 
 from peft import LoraConfig
 import intel_extension_for_pytorch as ipex
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training, \
+    cast_lora_weight
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
 import argparse

@@ -61,6 +61,8 @@ if __name__ == "__main__":
         task_type="CAUSAL_LM"
     )
     model = get_peft_model(model, config)
+    cast_lora_weight(model, torch.bfloat16)
+
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
     trainer = transformers.Trainer(

@@ -73,7 +75,6 @@ if __name__ == "__main__":
             max_steps=200,
             learning_rate=2e-5,
             save_steps=100,
-            # fp16=True,
             bf16=True,  # bf16 is more stable in training
             logging_steps=20,
             output_dir="outputs",

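In the plain QLoRA example the same pattern applies: `cast_lora_weight(model, torch.bfloat16)` immediately after `get_peft_model`, with `bf16=True` already set in the TrainingArguments. A small, hypothetical sanity check (not part of the diff) can confirm the cast took effect; adapter matrices should report torch.bfloat16 while norm layers stay in float32:

    import torch

    def check_lora_dtypes(model):
        # run after cast_lora_weight(model, torch.bfloat16)
        for name, module in model.named_modules():
            if hasattr(module, "weight") and module.weight is not None:
                if "lora_" in name or "norm" in name:
                    print(f"{name}: {module.weight.dtype}")
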
@@ -87,7 +87,10 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
 
     def forward(self, x: torch.Tensor):
         autocast_dtype = get_autocast_dtype(x)
-        if autocast_dtype is not None:
+        if x.device.type == "xpu":
+            # force to use bf16 on gpu
+            x = x.to(torch.bfloat16)
+        elif autocast_dtype is not None:
             x = x.to(autocast_dtype)
         result = super().forward(x)
 

@@ -95,7 +98,7 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
             return result
         elif self.r[self.active_adapter] > 0:
             result = result.clone()
-            if autocast_dtype is None:
+            if autocast_dtype is None and x.device.type == "cpu":
                 expected_dtype = result.dtype
                 x = x.to(self.lora_A[self.active_adapter].weight.dtype)
                 output = (

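The two forward() hunks encode one rule: on an XPU device the input is always cast to bfloat16 so the low-bit matmul and the LoRA adapters run in the same dtype, while other devices keep the previous autocast behaviour, and the fp32 up-cast of the LoRA output is now limited to the CPU path. An illustrative stand-alone version of that routing (not the library code; `get_autocast_dtype` is the module's own helper and is passed in here as a plain value):

    import torch

    def route_input_dtype(x: torch.Tensor, autocast_dtype=None):
        # sketch of the branch added to LoraLowBitLinear.forward
        if x.device.type == "xpu":
            return x.to(torch.bfloat16)   # force bf16 on the GPU path
        if autocast_dtype is not None:
            return x.to(autocast_dtype)   # honour torch autocast elsewhere
        return x                          # CPU without autocast keeps the original dtype
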
				
			
@@ -357,3 +360,15 @@ Accelerator._prepare_ipex = patch_prepare_ipex
 # patch transformer for xpu DDP traing
 from transformers import TrainingArguments
 TrainingArguments._setup_devices = _setup_devices
+
+
+def cast_lora_weight(model, dtype=torch.bfloat16):
+    for name, module in model.named_modules():
+        if isinstance(module, LoraLayer):
+            module = module.to(dtype)
+        if 'norm' in name:
+            module = module.to(torch.float32)
+        if 'lm_head' in name or 'embed_tokens' in name:
+            if hasattr(module, 'weight'):
+                if module.weight.dtype == torch.float32:
+                    module = module.to(dtype)

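The new `cast_lora_weight` helper applies a layer-wise policy: LoraLayer modules go to the requested dtype, anything whose name contains 'norm' is kept in float32 for numerical stability, and fp32 lm_head/embed_tokens weights are cast down as well. A runnable toy illustration of the same loop on plain torch modules (the Toy class and the helper name are inventions for this sketch; the real function also handles the LoraLayer branch, which needs a PEFT-wrapped model):

    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed_tokens = nn.Embedding(8, 4)
            self.input_layernorm = nn.LayerNorm(4)
            self.lm_head = nn.Linear(4, 8)

    def cast_like_cast_lora_weight(model, dtype=torch.bfloat16):
        # mirrors the loop above, minus the LoraLayer branch
        for name, module in model.named_modules():
            if 'norm' in name:
                module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight') and module.weight.dtype == torch.float32:
                    module.to(dtype)
        return model

    m = cast_like_cast_lora_weight(Toy())
    print(m.input_layernorm.weight.dtype)  # torch.float32
    print(m.lm_head.weight.dtype)          # torch.bfloat16
    print(m.embed_tokens.weight.dtype)     # torch.bfloat16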