LLM: GPU QLoRA update to bf16 to accelerate gradient checkpointing (#9499)

* update to bf16 to accelerate gradient checkpointing

* add utils and fix ut
Author: Ruonan Wang (committed by GitHub), 2023-11-21 17:08:36 +08:00
parent 3e39828420
commit 076d106ef5
4 changed files with 27 additions and 11 deletions
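
In practice, the commit threads bf16 through the whole GPU QLoRA fine-tuning path: the new cast_lora_weight utility casts the LoRA, norm, and embedding weights right after the PEFT model is built, and training runs with bf16=True, which is what speeds up gradient checkpointing. Below is a minimal sketch of that flow pieced together from the diffs; the base model name, the load_in_low_bit/optimize_model loading options, and the LoRA hyperparameters are illustrative assumptions rather than values taken from this commit.

    import torch
    import transformers
    import intel_extension_for_pytorch as ipex  # noqa: F401  (registers the "xpu" device)
    from peft import LoraConfig
    from bigdl.llm.transformers import AutoModelForCausalLM
    from bigdl.llm.transformers.qlora import (
        get_peft_model,
        prepare_model_for_kbit_training,
        cast_lora_weight,
    )

    # Assumed loading options; adjust to your checkpoint and bigdl-llm version.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",   # illustrative base model
        load_in_low_bit="nf4",        # 4-bit weights for QLoRA
        optimize_model=False,
        torch_dtype=torch.bfloat16,
    ).to("xpu")

    model = prepare_model_for_kbit_training(model)  # stock recipe, typically enables gradient checkpointing
    model = get_peft_model(model, LoraConfig(r=8, lora_alpha=32, task_type="CAUSAL_LM"))

    # New in this commit: cast LoRA weights to bf16 (norms stay fp32) so the
    # bf16 forward/backward under gradient checkpointing stays dtype-consistent.
    cast_lora_weight(model, torch.bfloat16)

    trainer = transformers.Trainer(
        model=model,
        # train_dataset omitted here; supply a tokenized causal-LM dataset before trainer.train()
        args=transformers.TrainingArguments(
            per_device_train_batch_size=2,
            bf16=True,                # bf16 is more stable in training
            gradient_checkpointing=True,
            output_dir="outputs",
        ),
    )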

@@ -51,7 +51,8 @@ import intel_extension_for_pytorch as ipex
 from bigdl.llm.transformers import AutoModelForCausalLM
 # import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
+    cast_lora_weight


 def get_int_from_env(env_keys, default):
     """Returns the first positive env value found in the `env_keys` list or the default."""
@@ -76,6 +77,7 @@ def train(
     data_path: str = "yahma/alpaca-cleaned",
     output_dir: str = "./bigdl-qlora-alpaca",
     # training hyperparams
+    bf16: bool = True,  # default to bf16
     batch_size: int = 128,
     micro_batch_size: int = 2,  # default to be 2, limited by GPU memory
     num_epochs: int = 3,
@@ -301,6 +303,9 @@ def train(
     # model.is_parallelizable = True
     # model.model_parallel = True

+    if bf16:
+        cast_lora_weight(model, torch.bfloat16)
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_data,

@@ -72,16 +72,11 @@ if __name__ == "__main__":
         0
     ].self_attn.q_proj.weight

-    assert torch.allclose(first_weight_old, first_weight)
-
     # merge weights - new merging method from peft
     lora_model = lora_model.merge_and_unload()

     lora_model.train(False)
-
-    # did we do anything?
-    assert not torch.allclose(first_weight_old, first_weight)

     lora_model_sd = lora_model.state_dict()
     deloreanized_sd = {
         k.replace("base_model.model.", ""): v

@@ -19,10 +19,10 @@ import os
 import transformers
 from transformers import LlamaTokenizer
 from peft import LoraConfig
 import intel_extension_for_pytorch as ipex
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training, \
+    cast_lora_weight
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
 import argparse
@@ -61,6 +61,8 @@ if __name__ == "__main__":
         task_type="CAUSAL_LM"
     )
     model = get_peft_model(model, config)
+    cast_lora_weight(model, torch.bfloat16)
+
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
     trainer = transformers.Trainer(
@@ -73,7 +75,6 @@ if __name__ == "__main__":
            max_steps=200,
            learning_rate=2e-5,
            save_steps=100,
-           # fp16=True,
            bf16=True,  # bf16 is more stable in training
            logging_steps=20,
            output_dir="outputs",

@@ -87,7 +87,10 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
     def forward(self, x: torch.Tensor):
         autocast_dtype = get_autocast_dtype(x)
-        if autocast_dtype is not None:
+        if x.device.type == "xpu":
+            # force to use bf16 on gpu
+            x = x.to(torch.bfloat16)
+        elif autocast_dtype is not None:
             x = x.to(autocast_dtype)
         result = super().forward(x)
@@ -95,7 +98,7 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
             return result
         elif self.r[self.active_adapter] > 0:
             result = result.clone()
-            if autocast_dtype is None:
+            if autocast_dtype is None and x.device.type == "cpu":
                 expected_dtype = result.dtype
                 x = x.to(self.lora_A[self.active_adapter].weight.dtype)
                 output = (
@@ -357,3 +360,15 @@ Accelerator._prepare_ipex = patch_prepare_ipex
 # patch transformer for xpu DDP traing
 from transformers import TrainingArguments
 TrainingArguments._setup_devices = _setup_devices
+
+
+def cast_lora_weight(model, dtype=torch.bfloat16):
+    for name, module in model.named_modules():
+        if isinstance(module, LoraLayer):
+            module = module.to(dtype)
+        if 'norm' in name:
+            module = module.to(torch.float32)
+        if 'lm_head' in name or 'embed_tokens' in name:
+            if hasattr(module, 'weight'):
+                if module.weight.dtype == torch.float32:
+                    module = module.to(dtype)
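
The new cast_lora_weight helper walks model.named_modules(), moves every LoraLayer to the requested dtype, keeps modules with "norm" in their name in float32, and downcasts any float32 lm_head/embed_tokens weights to the target dtype. A quick way to sanity-check the result is to tally parameter dtypes after the call; summarize_dtypes below is a hypothetical helper for that purpose, not part of bigdl-llm.

    def summarize_dtypes(model):
        """Hypothetical check: count parameters per (dtype, requires_grad) bucket."""
        counts = {}
        for _, p in model.named_parameters():
            key = (str(p.dtype), p.requires_grad)
            counts[key] = counts.get(key, 0) + p.numel()
        for (dtype, trainable), n in sorted(counts.items()):
            print(f"{dtype:>16}  trainable={trainable!s:<5}  params={n:,}")

    # Expected picture after cast_lora_weight(model, torch.bfloat16): trainable
    # LoRA A/B matrices in bfloat16, norm layers in float32, and any float32
    # lm_head/embed_tokens weights cast down to bfloat16.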