From 9975b029c56e4f983fd9d99737a23756cb9567b3 Mon Sep 17 00:00:00 2001 From: binbin Deng <108676127+plusbang@users.noreply.github.com> Date: Wed, 21 Feb 2024 16:40:04 +0800 Subject: [PATCH] LLM: add qlora finetuning example using `trl.SFTTrainer` (#10183) --- .../GPU/LLM-Finetuning/QLoRA/README.md | 2 +- .../QLoRA/trl-example/README.md | 55 +++++++++++ .../QLoRA/trl-example/export_merged_model.py | 44 +++++++++ .../QLoRA/trl-example/qlora_finetuning.py | 94 +++++++++++++++++++ 4 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/README.md create mode 100644 python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/export_merged_model.py create mode 100644 python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/qlora_finetuning.py diff --git a/python/llm/example/GPU/LLM-Finetuning/QLoRA/README.md b/python/llm/example/GPU/LLM-Finetuning/QLoRA/README.md index e5ad815c..2afebf3d 100644 --- a/python/llm/example/GPU/LLM-Finetuning/QLoRA/README.md +++ b/python/llm/example/GPU/LLM-Finetuning/QLoRA/README.md @@ -2,4 +2,4 @@ We provide [Alpaca-QLoRA example](./alpaca-qlora/), which ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM (using [QLoRA](https://arxiv.org/abs/2305.14314) algorithm) on [Intel GPU](../../README.md). -Meanwhile, we also provide a [simple example](./simple-example/) to help you get started with QLoRA Finetuning using BigDL-LLM. +Meanwhile, we also provide a [simple example](./simple-example/) to help you get started with QLoRA Finetuning using BigDL-LLM, and [TRL example](./trl-example/) to help you get started with QLoRA Finetuning using BigDL-LLM and TRL library. 
diff --git a/python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/README.md b/python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/README.md new file mode 100644 index 00000000..d17ca368 --- /dev/null +++ b/python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/README.md @@ -0,0 +1,55 @@ +# Example of QLoRA Finetuning with BigDL-LLM + +This simple example demonstrates how to finetune a llama2-7b model using BigDL-LLM 4bit optimizations with the TRL library on [Intel GPU](../../../README.md). +Note, this example is just used for illustrating related usage and doesn't guarantee convergence of training. + +## 0. Requirements +To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../../README.md#requirements) for more information. + +## Example: Finetune llama2-7b using qlora + +The `export_merged_model.py` is ported from [alpaca-lora](https://github.com/tloen/alpaca-lora/blob/main/export_hf_checkpoint.py). + +### 1. Install + +```bash +conda create -n llm python=3.9 +conda activate llm +# below command will install intel_extension_for_pytorch==2.1.10+xpu as default +pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu +pip install transformers==4.34.0 datasets +pip install peft==0.5.0 +pip install accelerate==0.23.0 +pip install bitsandbytes scipy trl +``` + +### 2. Configure OneAPI environment variables +```bash +source /opt/intel/oneapi/setvars.sh +``` + +### 3.
Finetune model + +``` +python ./qlora_finetuning.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH +``` + +#### Sample Output +```log +{'loss': 1.7386, 'learning_rate': 8.888888888888888e-06, 'epoch': 0.19} +{'loss': 1.9242, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.22} +{'loss': 1.6819, 'learning_rate': 4.444444444444444e-06, 'epoch': 0.26} +{'loss': 1.755, 'learning_rate': 2.222222222222222e-06, 'epoch': 0.29} +{'loss': 1.7455, 'learning_rate': 0.0, 'epoch': 0.32} +{'train_runtime': 172.8523, 'train_samples_per_second': 4.628, 'train_steps_per_second': 1.157, 'train_loss': 1.9101631927490235, 'epoch': 0.32} +100%|████████████████████████████████████████████| 200/200 [02:52<00:00, 1.16it/s] +TrainOutput(global_step=200, training_loss=1.9101631927490235, metrics={'train_runtime': 172.8523, 'train_samples_per_second': 4.628, 'train_steps_per_second': 1.157, 'train_loss': 1.9101631927490235, 'epoch': 0.32}) +``` + +### 4. Merge the adapter into the original model + +``` +python ./export_merged_model.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --adapter_path ./outputs/checkpoint-200 --output_path ./outputs/checkpoint-200-merged +``` + +Then you can use `./outputs/checkpoint-200-merged` as a normal huggingface transformer model to do inference. diff --git a/python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/export_merged_model.py b/python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/export_merged_model.py new file mode 100644 index 00000000..80902312 --- /dev/null +++ b/python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/export_merged_model.py @@ -0,0 +1,44 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +import torch +from transformers import LlamaTokenizer # noqa: F402 +import argparse + +current_dir = os.path.dirname(os.path.realpath(__file__)) +common_util_path = os.path.join(current_dir, '..', '..') +import sys +sys.path.append(common_util_path) +from common.utils import merge_adapter + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Merge the adapter into the original model for Llama2 model') + parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-hf", + help='The huggingface repo id for the Llama2 (e.g. 
`meta-llama/Llama-2-7b-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded' + ', or the path to the huggingface checkpoint folder') + parser.add_argument('--adapter_path', type=str,) + parser.add_argument('--output_path', type=str,) + + args = parser.parse_args() + base_model = model_path = args.repo_id_or_model_path + adapter_path = args.adapter_path + output_path = args.output_path + + tokenizer = LlamaTokenizer.from_pretrained(base_model) + merge_adapter(base_model, tokenizer, adapter_path, output_path) + print(f'Finish to merge the adapter into the original model and you could find the merged model in {output_path}.') diff --git a/python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/qlora_finetuning.py b/python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/qlora_finetuning.py new file mode 100644 index 00000000..eb34db48 --- /dev/null +++ b/python/llm/example/GPU/LLM-Finetuning/QLoRA/trl-example/qlora_finetuning.py @@ -0,0 +1,94 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import torch +import os + +import transformers +from transformers import LlamaTokenizer +from peft import LoraConfig +from transformers import BitsAndBytesConfig +from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training +from bigdl.llm.transformers import AutoModelForCausalLM +from datasets import load_dataset +from trl import SFTTrainer +import argparse + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Simple example of how to qlora finetune llama2 model using bigdl-llm and TRL') + parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-hf", + help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded' + ', or the path to the huggingface checkpoint folder') + parser.add_argument('--dataset', type=str, default="Abirate/english_quotes") + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + dataset_path = args.dataset + tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True) + + data = load_dataset(dataset_path, split="train") + + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=False, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16 + ) + model = AutoModelForCausalLM.from_pretrained(model_path, + quantization_config=bnb_config, ) + + # below is also supported + # model = AutoModelForCausalLM.from_pretrained(model_path, + # load_in_low_bit="nf4", + # optimize_model=False, + # torch_dtype=torch.bfloat16, + # modules_to_not_convert=["lm_head"],) + model = model.to('xpu') + # Enable gradient_checkpointing if your memory is not enough, + # it will slowdown the training speed + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True) + config = LoraConfig( + r=8, + lora_alpha=32, + target_modules=["q_proj", "k_proj", "v_proj"], + lora_dropout=0.05, + bias="none", + 
task_type="CAUSAL_LM", + ) + model = get_peft_model(model, config) + + trainer = SFTTrainer( + model=model, + train_dataset=data, + args=transformers.TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps= 1, + warmup_steps=20, + max_steps=200, + learning_rate=2e-5, + save_steps=100, + bf16=True, # bf16 is more stable in training + logging_steps=20, + output_dir="outputs", + optim="adamw_hf", # paged_adamw_8bit is not supported yet + gradient_checkpointing=True, # can further reduce memory but slower + ), + dataset_text_field="quote", + ) + model.config.use_cache = False # silence the warnings. Please re-enable for inference! + result = trainer.train() + print(result)