diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py
index f276b801..c08abbde 100644
--- a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py
+++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py
@@ -51,7 +51,8 @@ import intel_extension_for_pytorch as ipex
 from bigdl.llm.transformers import AutoModelForCausalLM
 # import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
+    cast_lora_weight
 
 
 def get_int_from_env(env_keys, default):
     """Returns the first positive env value found in the `env_keys` list or the default."""
@@ -76,6 +77,7 @@ def train(
     data_path: str = "yahma/alpaca-cleaned",
     output_dir: str = "./bigdl-qlora-alpaca",
     # training hyperparams
+    bf16: bool = True,  # default to bf16
     batch_size: int = 128,
     micro_batch_size: int = 2,  # default to be 2, limited by GPU memory
     num_epochs: int = 3,
@@ -301,6 +303,9 @@ def train(
     # model.is_parallelizable = True
     # model.model_parallel = True
 
+    if bf16:
+        cast_lora_weight(model, torch.bfloat16)
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_data,
diff --git a/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py b/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py
index 1cf3c2ff..2df8fee1 100644
--- a/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py
+++ b/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py
@@ -72,16 +72,11 @@ if __name__ == "__main__":
         0
     ].self_attn.q_proj.weight
 
-    assert torch.allclose(first_weight_old, first_weight)
-
     # merge weights - new merging method from peft
     lora_model = lora_model.merge_and_unload()
 
     lora_model.train(False)
 
-    # did we do anything?
-    assert not torch.allclose(first_weight_old, first_weight)
-
     lora_model_sd = lora_model.state_dict()
     deloreanized_sd = {
         k.replace("base_model.model.", ""): v
diff --git a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
index 36c94659..6e18bcee 100644
--- a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
+++ b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
@@ -19,10 +19,10 @@ import os
 
 import transformers
 from transformers import LlamaTokenizer
-
 from peft import LoraConfig
 import intel_extension_for_pytorch as ipex
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training, \
+    cast_lora_weight
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
 import argparse
@@ -61,6 +61,8 @@ if __name__ == "__main__":
         task_type="CAUSAL_LM"
     )
     model = get_peft_model(model, config)
+    cast_lora_weight(model, torch.bfloat16)
+
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
     trainer = transformers.Trainer(
@@ -73,7 +75,6 @@ if __name__ == "__main__":
             max_steps=200,
             learning_rate=2e-5,
             save_steps=100,
-            # fp16=True,
             bf16=True,  # bf16 is more stable in training
             logging_steps=20,
             output_dir="outputs",
diff --git a/python/llm/src/bigdl/llm/transformers/qlora.py b/python/llm/src/bigdl/llm/transformers/qlora.py
index 35b1f0ef..18a0cc5f 100644
--- a/python/llm/src/bigdl/llm/transformers/qlora.py
+++ b/python/llm/src/bigdl/llm/transformers/qlora.py
@@ -87,7 +87,10 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
 
     def forward(self, x: torch.Tensor):
         autocast_dtype = get_autocast_dtype(x)
-        if autocast_dtype is not None:
+        if x.device.type == "xpu":
+            # force to use bf16 on gpu
+            x = x.to(torch.bfloat16)
+        elif autocast_dtype is not None:
             x = x.to(autocast_dtype)
         result = super().forward(x)
 
@@ -95,7 +98,7 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
             return result
         elif self.r[self.active_adapter] > 0:
             result = result.clone()
-            if autocast_dtype is None:
+            if autocast_dtype is None and x.device.type == "cpu":
                 expected_dtype = result.dtype
                 x = x.to(self.lora_A[self.active_adapter].weight.dtype)
                 output = (
@@ -357,3 +360,15 @@ Accelerator._prepare_ipex = patch_prepare_ipex
 # patch transformer for xpu DDP traing
 from transformers import TrainingArguments
 TrainingArguments._setup_devices = _setup_devices
+
+
+def cast_lora_weight(model, dtype=torch.bfloat16):
+    for name, module in model.named_modules():
+        if isinstance(module, LoraLayer):
+            module = module.to(dtype)
+        if 'norm' in name:
+            module = module.to(torch.float32)
+        if 'lm_head' in name or 'embed_tokens' in name:
+            if hasattr(module, 'weight'):
+                if module.weight.dtype == torch.float32:
+                    module = module.to(dtype)
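
Usage note (not part of the patch): the sketch below shows how the new cast_lora_weight helper is meant to be called, following the call order this diff introduces in qlora_finetuning.py: build the low-bit model, wrap it with LoRA via get_peft_model, then cast the LoRA weights to bf16 before handing the model to the Trainer with bf16=True. The base-model id, the from_pretrained arguments, and the LoRA hyperparameters are illustrative assumptions, not values taken from this change.

# Minimal sketch, assuming a Llama-style base model and an Intel GPU ("xpu") device.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (registers the "xpu" device)
from peft import LoraConfig
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training, \
    cast_lora_weight

base_model = "meta-llama/Llama-2-7b-hf"  # placeholder model id
model = AutoModelForCausalLM.from_pretrained(base_model,
                                             load_in_low_bit="nf4",  # 4-bit base weights for QLoRA
                                             optimize_model=False)
model = model.to("xpu")
model = prepare_model_for_kbit_training(model)

config = LoraConfig(r=8, lora_alpha=32,
                    target_modules=["q_proj", "k_proj", "v_proj"],  # illustrative choices
                    lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
model = get_peft_model(model, config)

# New helper added by this change: LoRA layers are cast to bf16, modules whose
# name contains "norm" are kept in fp32, and an fp32 lm_head/embed_tokens is
# cast to bf16, so training can then run with TrainingArguments(bf16=True, ...).
cast_lora_weight(model, torch.bfloat16)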