LLM: GPU QLoRA update to bf16 to accelerate gradient checkpointing (#9499)
* update to bf16 to accelerate gradient checkpointing
* add utils and fix unit tests
parent 3e39828420
commit 076d106ef5
4 changed files with 27 additions and 11 deletions
@@ -51,7 +51,8 @@ import intel_extension_for_pytorch as ipex
 from bigdl.llm.transformers import AutoModelForCausalLM
 
 # import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
+    cast_lora_weight
 
 def get_int_from_env(env_keys, default):
     """Returns the first positive env value found in the `env_keys` list or the default."""
@@ -76,6 +77,7 @@ def train(
     data_path: str = "yahma/alpaca-cleaned",
     output_dir: str = "./bigdl-qlora-alpaca",
     # training hyperparams
+    bf16: bool = True,  # default to bf16
     batch_size: int = 128,
     micro_batch_size: int = 2,  # default to be 2, limited by GPU memory
     num_epochs: int = 3,
@@ -301,6 +303,9 @@ def train(
     # model.is_parallelizable = True
     # model.model_parallel = True
 
+    if bf16:
+        cast_lora_weight(model, torch.bfloat16)
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_data,
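
Taken with the import change in the first hunk, the new `bf16` flag gates exactly one extra step between `get_peft_model` and the `Trainer`. Below is a minimal sketch of that ordering, assuming a base model already loaded through BigDL-LLM in low-bit form; the `build_qlora_trainer` helper and the bare `LoraConfig` defaults are illustrative and not part of the example script.

    import torch
    import transformers
    from peft import LoraConfig
    from bigdl.llm.transformers.qlora import (
        get_peft_model,
        prepare_model_for_kbit_training,
        cast_lora_weight,
    )


    def build_qlora_trainer(base_model, train_data, training_args, bf16=True):
        """Illustrative wiring of a low-bit base model into a bf16 QLoRA Trainer."""
        model = prepare_model_for_kbit_training(base_model)
        model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))
        if bf16:
            # The step this commit adds to train(): cast the LoRA adapter
            # weights to bf16 before handing the model to the Trainer.
            cast_lora_weight(model, torch.bfloat16)
        return transformers.Trainer(
            model=model, train_dataset=train_data, args=training_args
        )
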
@@ -72,16 +72,11 @@ if __name__ == "__main__":
         0
     ].self_attn.q_proj.weight
 
-    assert torch.allclose(first_weight_old, first_weight)
-
     # merge weights - new merging method from peft
     lora_model = lora_model.merge_and_unload()
 
     lora_model.train(False)
 
-    # did we do anything?
-    assert not torch.allclose(first_weight_old, first_weight)
-
     lora_model_sd = lora_model.state_dict()
     deloreanized_sd = {
         k.replace("base_model.model.", ""): v
@@ -19,10 +19,10 @@ import os
 
 import transformers
 from transformers import LlamaTokenizer
-
 from peft import LoraConfig
 import intel_extension_for_pytorch as ipex
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training, \
+    cast_lora_weight
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
 import argparse
@@ -61,6 +61,8 @@ if __name__ == "__main__":
         task_type="CAUSAL_LM"
     )
     model = get_peft_model(model, config)
+    cast_lora_weight(model, torch.bfloat16)
+
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
     trainer = transformers.Trainer(
@@ -73,7 +75,6 @@ if __name__ == "__main__":
             max_steps=200,
             learning_rate=2e-5,
             save_steps=100,
-            # fp16=True,
             bf16=True,  # bf16 is more stable in training
             logging_steps=20,
             output_dir="outputs",
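
The `bf16=True` flag above ties into the commit title: once the LoRA weights and inputs are kept in bf16, the recomputation performed by gradient checkpointing also runs in bf16 rather than fp32. For context, here is a hedged sketch of the surrounding `TrainingArguments`; only the values visible in this hunk come from the script, and the remaining fields, including `gradient_checkpointing=True`, are assumptions.

    import transformers

    training_args = transformers.TrainingArguments(
        per_device_train_batch_size=4,   # assumption
        gradient_checkpointing=True,     # assumption: the feature the bf16 switch is meant to speed up
        max_steps=200,
        learning_rate=2e-5,
        save_steps=100,
        bf16=True,                       # bf16 is more stable in training
        logging_steps=20,
        output_dir="outputs",
    )
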
@@ -87,7 +87,10 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
 
     def forward(self, x: torch.Tensor):
         autocast_dtype = get_autocast_dtype(x)
-        if autocast_dtype is not None:
+        if x.device.type == "xpu":
+            # force to use bf16 on gpu
+            x = x.to(torch.bfloat16)
+        elif autocast_dtype is not None:
             x = x.to(autocast_dtype)
         result = super().forward(x)
 
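
Together with the CPU guard in the next hunk, the effect is that on Intel GPUs (XPU) the LoRA path always computes in bf16, while the existing autocast and fp32 handling is kept for other devices. A standalone restatement of that dtype selection follows; `select_lora_input_dtype` is an illustrative helper, not part of the patch.

    import torch


    def select_lora_input_dtype(x: torch.Tensor, autocast_dtype):
        """Restates the input-cast logic of LoraLowBitLinear.forward after this change."""
        if x.device.type == "xpu":
            # On Intel GPU, force bf16 so the (checkpointed) forward pass
            # stays in bf16 instead of falling back to fp32.
            return x.to(torch.bfloat16)
        if autocast_dtype is not None:
            # Elsewhere, follow whatever dtype autocast is currently using.
            return x.to(autocast_dtype)
        # No autocast: leave the input unchanged; the CPU-only branch later in
        # forward() handles the fp32 round-trip around the LoRA matmul.
        return x
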
@@ -95,7 +98,7 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
             return result
         elif self.r[self.active_adapter] > 0:
             result = result.clone()
-            if autocast_dtype is None:
+            if autocast_dtype is None and x.device.type == "cpu":
                 expected_dtype = result.dtype
                 x = x.to(self.lora_A[self.active_adapter].weight.dtype)
                 output = (
@@ -357,3 +360,15 @@ Accelerator._prepare_ipex = patch_prepare_ipex
 # patch transformer for xpu DDP training
 from transformers import TrainingArguments
 TrainingArguments._setup_devices = _setup_devices
+
+
+def cast_lora_weight(model, dtype=torch.bfloat16):
+    for name, module in model.named_modules():
+        if isinstance(module, LoraLayer):
+            module = module.to(dtype)
+        if 'norm' in name:
+            module = module.to(torch.float32)
+        if 'lm_head' in name or 'embed_tokens' in name:
+            if hasattr(module, 'weight'):
+                if module.weight.dtype == torch.float32:
+                    module = module.to(dtype)
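
The new `cast_lora_weight` helper walks the model once: LoRA layers are moved to the target dtype, any module with `norm` in its name is kept in fp32, and `lm_head`/`embed_tokens` weights are only downcast if they are still fp32. A hedged usage sketch; `cast_and_report` and the dtype printout are for inspection only and not part of the commit.

    import torch
    from bigdl.llm.transformers.qlora import cast_lora_weight


    def cast_and_report(peft_model):
        """Cast a PEFT-wrapped model's LoRA weights to bf16 and print the result."""
        cast_lora_weight(peft_model, torch.bfloat16)
        for name, param in peft_model.named_parameters():
            if "lora_" in name or "norm" in name:
                # Expect lora_A / lora_B weights in bf16 and norm layers in fp32.
                print(f"{name}: {param.dtype}")
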