From 0f82b8c3a0dd8676754181b2df36166887560a34 Mon Sep 17 00:00:00 2001 From: Ruonan Wang <105281011+rnwang04@users.noreply.github.com> Date: Wed, 15 Nov 2023 09:24:15 +0800 Subject: [PATCH] LLM: update qlora example (#9454) * update qlora example * fix loss=0 --- python/llm/example/GPU/QLoRA-FineTuning/README.md | 2 +- .../llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/llm/example/GPU/QLoRA-FineTuning/README.md b/python/llm/example/GPU/QLoRA-FineTuning/README.md index 891abe9e..14667b57 100644 --- a/python/llm/example/GPU/QLoRA-FineTuning/README.md +++ b/python/llm/example/GPU/QLoRA-FineTuning/README.md @@ -17,7 +17,7 @@ conda activate llm # below command will install intel_extension_for_pytorch==2.0.110+xpu as default # you can install specific ipex/torch version for your need pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu -pip install transformers==4.34.0 +pip install datasets transformers==4.34.0 pip install peft==0.5.0 pip install accelerate==0.23.0 ``` diff --git a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py index 9c8e2d1f..36c94659 100644 --- a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py +++ b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py @@ -48,7 +48,9 @@ if __name__ == "__main__": torch_dtype=torch.float16, modules_to_not_convert=["lm_head"],) model = model.to('xpu') - model.gradient_checkpointing_enable() + # Enable gradient_checkpointing if your memory is not enough, + # it will slow down the training speed + # model.gradient_checkpointing_enable() model = prepare_model_for_kbit_training(model) config = LoraConfig( r=8, @@ -69,9 +71,10 @@ if __name__ == "__main__": gradient_accumulation_steps= 1, warmup_steps=20, max_steps=200, - learning_rate=2e-4, + learning_rate=2e-5, save_steps=100, - fp16=True, + # fp16=True, + 
bf16=True, # bf16 is more stable in training logging_steps=20, output_dir="outputs", optim="adamw_hf", # paged_adamw_8bit is not supported yet