LLM: GPU QLoRA update to bf16 to accelerate gradient checkpointing (#9499)
* update to bf16 to accelerate gradient checkpoint * add utils and fix ut
This commit is contained in:
parent
3e39828420
commit
076d106ef5
4 changed files with 27 additions and 11 deletions
|
|
@ -51,7 +51,8 @@ import intel_extension_for_pytorch as ipex
|
|||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
|
||||
# import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
|
||||
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
|
||||
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
|
||||
cast_lora_weight
|
||||
|
||||
def get_int_from_env(env_keys, default):
|
||||
"""Returns the first positive env value found in the `env_keys` list or the default."""
|
||||
|
|
@ -76,6 +77,7 @@ def train(
|
|||
data_path: str = "yahma/alpaca-cleaned",
|
||||
output_dir: str = "./bigdl-qlora-alpaca",
|
||||
# training hyperparams
|
||||
bf16: bool = True, # default to bf16
|
||||
batch_size: int = 128,
|
||||
micro_batch_size: int = 2, # default to be 2, limited by GPU memory
|
||||
num_epochs: int = 3,
|
||||
|
|
@ -301,6 +303,9 @@ def train(
|
|||
# model.is_parallelizable = True
|
||||
# model.model_parallel = True
|
||||
|
||||
if bf16:
|
||||
cast_lora_weight(model, torch.bfloat16)
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
train_dataset=train_data,
|
||||
|
|
|
|||
|
|
@ -72,16 +72,11 @@ if __name__ == "__main__":
|
|||
0
|
||||
].self_attn.q_proj.weight
|
||||
|
||||
assert torch.allclose(first_weight_old, first_weight)
|
||||
|
||||
# merge weights - new merging method from peft
|
||||
lora_model = lora_model.merge_and_unload()
|
||||
|
||||
lora_model.train(False)
|
||||
|
||||
# did we do anything?
|
||||
assert not torch.allclose(first_weight_old, first_weight)
|
||||
|
||||
lora_model_sd = lora_model.state_dict()
|
||||
deloreanized_sd = {
|
||||
k.replace("base_model.model.", ""): v
|
||||
|
|
|
|||
|
|
@ -19,10 +19,10 @@ import os
|
|||
|
||||
import transformers
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
from peft import LoraConfig
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
|
||||
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training, \
|
||||
cast_lora_weight
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
from datasets import load_dataset
|
||||
import argparse
|
||||
|
|
@ -61,6 +61,8 @@ if __name__ == "__main__":
|
|||
task_type="CAUSAL_LM"
|
||||
)
|
||||
model = get_peft_model(model, config)
|
||||
cast_lora_weight(model, torch.bfloat16)
|
||||
|
||||
tokenizer.pad_token_id = 0
|
||||
tokenizer.padding_side = "left"
|
||||
trainer = transformers.Trainer(
|
||||
|
|
@ -73,7 +75,6 @@ if __name__ == "__main__":
|
|||
max_steps=200,
|
||||
learning_rate=2e-5,
|
||||
save_steps=100,
|
||||
# fp16=True,
|
||||
bf16=True, # bf16 is more stable in training
|
||||
logging_steps=20,
|
||||
output_dir="outputs",
|
||||
|
|
|
|||
|
|
@ -87,7 +87,10 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
|
|||
|
||||
def forward(self, x: torch.Tensor):
|
||||
autocast_dtype = get_autocast_dtype(x)
|
||||
if autocast_dtype is not None:
|
||||
if x.device.type == "xpu":
|
||||
# force to use bf16 on gpu
|
||||
x = x.to(torch.bfloat16)
|
||||
elif autocast_dtype is not None:
|
||||
x = x.to(autocast_dtype)
|
||||
result = super().forward(x)
|
||||
|
||||
|
|
@ -95,7 +98,7 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
|
|||
return result
|
||||
elif self.r[self.active_adapter] > 0:
|
||||
result = result.clone()
|
||||
if autocast_dtype is None:
|
||||
if autocast_dtype is None and x.device.type == "cpu":
|
||||
expected_dtype = result.dtype
|
||||
x = x.to(self.lora_A[self.active_adapter].weight.dtype)
|
||||
output = (
|
||||
|
|
@ -357,3 +360,15 @@ Accelerator._prepare_ipex = patch_prepare_ipex
|
|||
# patch transformer for xpu DDP traing
|
||||
from transformers import TrainingArguments
|
||||
TrainingArguments._setup_devices = _setup_devices
|
||||
|
||||
|
||||
def cast_lora_weight(model, dtype=torch.bfloat16):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LoraLayer):
|
||||
module = module.to(dtype)
|
||||
if 'norm' in name:
|
||||
module = module.to(torch.float32)
|
||||
if 'lm_head' in name or 'embed_tokens' in name:
|
||||
if hasattr(module, 'weight'):
|
||||
if module.weight.dtype == torch.float32:
|
||||
module = module.to(dtype)
|
||||
|
|
|
|||
Loading…
Reference in a new issue