From aae20d728e352e27d1fce24816ac30da3ce1c77c Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Thu, 1 Feb 2024 14:18:08 +0800
Subject: [PATCH] LLM: Add initial DPO finetuning example (#10021)

---
 .../example/GPU/LLM-Finetuning/DPO/README.md  |  56 ++++++
 .../GPU/LLM-Finetuning/DPO/dpo_finetuning.py  | 179 ++++++++++++++++++
 .../LLM-Finetuning/DPO/export_merged_model.py |  44 +++++
 .../llm/example/GPU/LLM-Finetuning/README.md  |   1 +
 4 files changed, 280 insertions(+)
 create mode 100644 python/llm/example/GPU/LLM-Finetuning/DPO/README.md
 create mode 100644 python/llm/example/GPU/LLM-Finetuning/DPO/dpo_finetuning.py
 create mode 100644 python/llm/example/GPU/LLM-Finetuning/DPO/export_merged_model.py

diff --git a/python/llm/example/GPU/LLM-Finetuning/DPO/README.md b/python/llm/example/GPU/LLM-Finetuning/DPO/README.md
new file mode 100644
index 00000000..3992f16c
--- /dev/null
+++ b/python/llm/example/GPU/LLM-Finetuning/DPO/README.md
@@ -0,0 +1,56 @@
# Simple Example of DPO Finetuning with BigDL-LLM

This simple example demonstrates how to finetune a Mistral-7B model with BigDL-LLM 4-bit optimizations on [Intel GPUs](../../README.md).
Note that this example is intended only to illustrate the related usage.

## 0. Requirements
To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../README.md#requirements) for more information.

## Example: Finetune Mistral-7B using DPO

This example is ported from [Fine_tune_a_Mistral_7b_model_with_DPO](https://github.com/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb).

### 1. Install

```bash
conda create -n llm python=3.9
conda activate llm
# the command below will install intel_extension_for_pytorch==2.1.10+xpu by default
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
pip install transformers==4.34.0 datasets
pip install trl peft==0.5.0
pip install accelerate==0.23.0
pip install bitsandbytes
```

### 2. Configure OneAPI environment variables
```bash
source /opt/intel/oneapi/setvars.sh
```

### 3. Finetune model

```
python ./dpo_finetuning.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --gradient-checkpointing
```
> Note: The final LoRA weights and configurations are saved to `./outputs` by default. You can change the output path by specifying `--output-path`.
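For example, a run that writes the adapter to a custom directory could look like the following (the directory name here is only an illustration):

```
python ./dpo_finetuning.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --gradient-checkpointing --output-path ./my-dpo-outputs
```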

#### Sample Output
```log
trainable params: 41,943,040 || all params: 4,012,134,400 || trainable%: 1.0454046604221434
{'loss': 0.6931, 'learning_rate': 5.000000000000001e-07, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/rejected': -271.842041015625, 'logps/chosen': -146.93634033203125, 'logits/rejected': -2.9851596355438232, 'logits/chosen': -2.98481822013855, 'epoch': 0.0}
{'loss': 0.6931, 'learning_rate': 1.0000000000000002e-06, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/rejected': -248.09817504882812, 'logps/chosen': -259.561767578125, 'logits/rejected': -2.967536449432373, 'logits/chosen': -2.951939582824707, 'epoch': 0.0}
{'loss': 0.7058, 'learning_rate': 1.5e-06, 'rewards/chosen': -0.006700039375573397, 'rewards/rejected': 0.016817521303892136, 'rewards/accuracies': 0.4375, 'rewards/margins': -0.023517560213804245, 'logps/rejected': -183.52743530273438, 'logps/chosen': -122.3787841796875, 'logits/rejected': -2.948030471801758, 'logits/chosen': -2.9321558475494385, 'epoch': 0.0}
{'loss': 0.6912, 'learning_rate': 2.0000000000000003e-06, 'rewards/chosen': 0.0014888052828609943, 'rewards/rejected': -0.004842948634177446, 'rewards/accuracies': 0.625, 'rewards/margins': 0.006331752985715866, 'logps/rejected': -234.07257080078125, 'logps/chosen': -181.22940063476562, 'logits/rejected': -2.938673496246338, 'logits/chosen': -2.9304277896881104, 'epoch': 0.0}
{'loss': 0.6958, 'learning_rate': 2.5e-06, 'rewards/chosen': -0.001946449396200478, 'rewards/rejected': 0.0025150063447654247, 'rewards/accuracies': 0.5625, 'rewards/margins': -0.004461456090211868, 'logps/rejected': -263.15106201171875, 'logps/chosen': -242.25759887695312, 'logits/rejected': -2.931898832321167, 'logits/chosen': -2.9180212020874023, 'epoch': 0.01}
{'loss': 0.6714, 'learning_rate': 3e-06, 'rewards/chosen': 0.002834760583937168, 'rewards/rejected': -0.043302297592163086, 'rewards/accuracies': 0.625, 'rewards/margins': 0.04613706097006798, 'logps/rejected': -269.76953125, 'logps/chosen': -175.4458465576172, 'logits/rejected': -2.863767147064209, 'logits/chosen': -2.813715696334839, 'epoch': 0.01}
{'loss': 0.6773, 'learning_rate': 3.5000000000000004e-06, 'rewards/chosen': -0.000818049069494009, 'rewards/rejected': -0.03519792854785919, 'rewards/accuracies': 0.6875, 'rewards/margins': 0.034379877150058746, 'logps/rejected': -307.48388671875, 'logps/chosen': -258.1222839355469, 'logits/rejected': -2.93851900100708, 'logits/chosen': -2.845832347869873, 'epoch': 0.01}
```

### 4. Merge the adapter into the original model

```
python ./export_merged_model.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --adapter_path ./outputs --output_path ./outputs/merged-model
```

Then you can use `./outputs/merged-model` as a normal Hugging Face Transformers model for inference.
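As a quick check of the merged model, a minimal inference sketch along the following lines should work; the prompt, generation settings, and the use of plain Hugging Face `transformers` here are illustrative assumptions rather than part of the original example:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Path produced by export_merged_model.py above (illustrative)
merged_path = "./outputs/merged-model"
tokenizer = AutoTokenizer.from_pretrained(merged_path)
model = AutoModelForCausalLM.from_pretrained(merged_path, torch_dtype=torch.bfloat16)

# OpenHermes-2.5-Mistral-7B uses the ChatML format, which the tokenizer's chat template applies
messages = [{"role": "user", "content": "Briefly explain what DPO finetuning does."}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

# Generate a short completion and print only the newly generated tokens
output = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True))
```
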
diff --git a/python/llm/example/GPU/LLM-Finetuning/DPO/dpo_finetuning.py b/python/llm/example/GPU/LLM-Finetuning/DPO/dpo_finetuning.py
new file mode 100644
index 00000000..3d0708b5
--- /dev/null
+++ b/python/llm/example/GPU/LLM-Finetuning/DPO/dpo_finetuning.py
@@ -0,0 +1,179 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb
#
# Copyright [yyyy] [name of copyright owner]

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import torch

import transformers
from transformers import AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
from bigdl.llm.transformers import AutoModelForCausalLM
from trl import DPOTrainer
import argparse


def chatml_format(example):
    # Format system
    if len(example['system']) > 0:
        message = {"role": "system", "content": example['system']}
        system = tokenizer.apply_chat_template([message], tokenize=False)
    else:
        system = ""

    # Format instruction
    message = {"role": "user", "content": example['question']}
    prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)

    # Format chosen answer
    chosen = example['chosen'] + "<|im_end|>\n"

    # Format rejected answer
    rejected = example['rejected'] + "<|im_end|>\n"

    return {
        "prompt": system + prompt,
        "chosen": chosen,
        "rejected": rejected,
    }

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Finetune a Mistral-7b model with DPO')
    parser.add_argument('--repo-id-or-model-path', type=str, default="teknium/OpenHermes-2.5-Mistral-7B",
                        help='The huggingface repo id for the Mistral (e.g. `teknium/OpenHermes-2.5-Mistral-7B`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--dataset', type=str, default="Intel/orca_dpo_pairs")
    parser.add_argument('--output-path', type=str, default="outputs")
    parser.add_argument('--gradient-checkpointing', action='store_true', help='Whether to enable gradient checkpointing to save memory at the expense of slower backward pass.')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    dataset_path = args.dataset
    output_path = args.output_path
    gradient_checkpointing = args.gradient_checkpointing

    # Load dataset
    dataset = load_dataset(dataset_path)['train']

    # Save columns
    original_columns = dataset.column_names

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Format dataset
    dataset = dataset.map(
        chatml_format,
        remove_columns=original_columns
    )

    # LoRA configuration
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj']
    )

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 quantization_config=bnb_config, )

    # below is also supported
    # model = AutoModelForCausalLM.from_pretrained(model_path,
    #                                              load_in_low_bit="nf4",
    #                                              optimize_model=False,
    #                                              torch_dtype=torch.bfloat16,
    #                                              modules_to_not_convert=["lm_head"],)

    model = model.to('xpu')
    # Prepare a BigDL-LLM compatible Peft model
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=gradient_checkpointing)
    model = get_peft_model(model, peft_config)
    model.config.use_cache = False
    model.print_trainable_parameters()

    # Reference model, same as the main one
    ref_model = AutoModelForCausalLM.from_pretrained(model_path,
                                                     load_in_low_bit="nf4",
                                                     optimize_model=False,
                                                     torch_dtype=torch.bfloat16,
                                                     modules_to_not_convert=["lm_head"],)
    ref_model = ref_model.to('xpu')

    # Training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        gradient_checkpointing=gradient_checkpointing,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        max_steps=200,
        save_strategy="no",
        logging_steps=1,
        output_dir=output_path,
        # optim="paged_adamw_32bit", # "paged_adamw_32bit" is not supported yet
        optim="adamw_hf",
        warmup_steps=100,
        bf16=True,
    )

    # Create DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        beta=0.1,
        max_prompt_length=1024,
        max_length=1536,
    )

    # Fine-tune model with DPO
    dpo_trainer.train()

    # Save artifacts
    dpo_trainer.model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)

diff --git a/python/llm/example/GPU/LLM-Finetuning/DPO/export_merged_model.py b/python/llm/example/GPU/LLM-Finetuning/DPO/export_merged_model.py
new file mode 100644
index 00000000..910a0513
--- /dev/null
+++ b/python/llm/example/GPU/LLM-Finetuning/DPO/export_merged_model.py
@@ -0,0 +1,44 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os

import torch
from transformers import AutoTokenizer
import argparse

current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..')
import sys
sys.path.append(common_util_path)
from common.utils import merge_adapter

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Merge the adapter into the original Mistral model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="teknium/OpenHermes-2.5-Mistral-7B",
                        help='The huggingface repo id for the Mistral (e.g. `teknium/OpenHermes-2.5-Mistral-7B`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--adapter_path', type=str,)
    parser.add_argument('--output_path', type=str,)

    args = parser.parse_args()
    base_model = model_path = args.repo_id_or_model_path
    adapter_path = args.adapter_path
    output_path = args.output_path

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    merge_adapter(base_model, tokenizer, adapter_path, output_path)
    print(f'Finished merging the adapter into the original model. You can find the merged model in {output_path}.')

diff --git a/python/llm/example/GPU/LLM-Finetuning/README.md b/python/llm/example/GPU/LLM-Finetuning/README.md
index c8d59c39..114a73cb 100644
--- a/python/llm/example/GPU/LLM-Finetuning/README.md
+++ b/python/llm/example/GPU/LLM-Finetuning/README.md
@@ -6,4 +6,5 @@ This folder contains examples of running different training mode with BigDL-LLM
 - [QLoRA](QLoRA): examples of running QLoRA finetuning
 - [QA-LoRA](QA-LoRA): examples of running QA-LoRA finetuning
 - [ReLora](ReLora): examples of running ReLora finetuning
+- [DPO](DPO): examples of running DPO finetuning
 - [common](common): common templates and utility classes in finetuning examples