LLM: Add initial DPO finetuning example (#10021)

This commit is contained in:
binbin Deng 2024-02-01 14:18:08 +08:00 committed by GitHub
parent 601024f418
commit aae20d728e
4 changed files with 280 additions and 0 deletions

View file

@ -0,0 +1,56 @@
# Simple Example of DPO Finetuning with BigDL-LLM
This simple example demonstrates how to finetune a Mistral-7B model use BigDL-LLM 4bit optimizations using [Intel GPUs](../../README.md).
Note, this example is just used for illustrating related usage.
## 0. Requirements
To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../../README.md#requirements) for more information.
## Example: Finetune Mistral-7b using DPO
This example is ported from [Fine_tune_a_Mistral_7b_model_with_DPO](https://github.com/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb).
### 1. Install
```bash
conda create -n llm python=3.9
conda activate llm
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
pip install transformers==4.34.0 datasets
pip install trl peft==0.5.0
pip install accelerate==0.23.0
pip install bitsandbytes
```
### 2. Configures OneAPI environment variables
```bash
source /opt/intel/oneapi/setvars.sh
```
### 3. Finetune model
```
python ./dpo_finetuning.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --gradient-checkpointing
```
> Note: The final LoRA weights and configurations are saved to './outputs' by default. You could also change the output path through specifying `--output-path`.
#### Sample Output
```log
trainable params: 41,943,040 || all params: 4,012,134,400 || trainable%: 1.0454046604221434
{'loss': 0.6931, 'learning_rate': 5.000000000000001e-07, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/rejected': -271.842041015625, 'logps/chosen': -146.93634033203125, 'logits/rejected': -2.9851596355438232, 'logits/chosen': -2.98481822013855, 'epoch': 0.0}
{'loss': 0.6931, 'learning_rate': 1.0000000000000002e-06, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/rejected': -248.09817504882812, 'logps/chosen': -259.561767578125, 'logits/rejected': -2.967536449432373, 'logits/chosen': -2.951939582824707, 'epoch': 0.0}
{'loss': 0.7058, 'learning_rate': 1.5e-06, 'rewards/chosen': -0.006700039375573397, 'rewards/rejected': 0.016817521303892136, 'rewards/accuracies': 0.4375, 'rewards/margins': -0.023517560213804245, 'logps/rejected': -183.52743530273438, 'logps/chosen': -122.3787841796875, 'logits/rejected': -2.948030471801758, 'logits/chosen': -2.9321558475494385, 'epoch': 0.0}
{'loss': 0.6912, 'learning_rate': 2.0000000000000003e-06, 'rewards/chosen': 0.0014888052828609943, 'rewards/rejected': -0.004842948634177446, 'rewards/accuracies': 0.625, 'rewards/margins': 0.006331752985715866, 'logps/rejected': -234.07257080078125, 'logps/chosen': -181.22940063476562, 'logits/rejected': -2.938673496246338, 'logits/chosen': -2.9304277896881104, 'epoch': 0.0}
{'loss': 0.6958, 'learning_rate': 2.5e-06, 'rewards/chosen': -0.001946449396200478, 'rewards/rejected': 0.0025150063447654247, 'rewards/accuracies': 0.5625, 'rewards/margins': -0.004461456090211868, 'logps/rejected': -263.15106201171875, 'logps/chosen': -242.25759887695312, 'logits/rejected': -2.931898832321167, 'logits/chosen': -2.9180212020874023, 'epoch': 0.01}
{'loss': 0.6714, 'learning_rate': 3e-06, 'rewards/chosen': 0.002834760583937168, 'rewards/rejected': -0.043302297592163086, 'rewards/accuracies': 0.625, 'rewards/margins': 0.04613706097006798, 'logps/rejected': -269.76953125, 'logps/chosen': -175.4458465576172, 'logits/rejected': -2.863767147064209, 'logits/chosen': -2.813715696334839, 'epoch': 0.01}
{'loss': 0.6773, 'learning_rate': 3.5000000000000004e-06, 'rewards/chosen': -0.000818049069494009, 'rewards/rejected': -0.03519792854785919, 'rewards/accuracies': 0.6875, 'rewards/margins': 0.034379877150058746, 'logps/rejected': -307.48388671875, 'logps/chosen': -258.1222839355469, 'logits/rejected': -2.93851900100708, 'logits/chosen': -2.845832347869873, 'epoch': 0.01}
```
### 4. Merge the adapter into the original model
```
python ./export_merged_model.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --adapter_path ./outputs --output_path ./outputs/merged-model
```
Then you can use `./outputs/merged-model` as a normal huggingface transformer model to do inference.

View file

@ -0,0 +1,179 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb
#
# Copyright [yyyy] [name of copyright owner]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import torch
import transformers
from transformers import AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
from bigdl.llm.transformers import AutoModelForCausalLM
from trl import DPOTrainer
import argparse
def chatml_format(example):
# Format system
if len(example['system']) > 0:
message = {"role": "system", "content": example['system']}
system = tokenizer.apply_chat_template([message], tokenize=False)
else:
system = ""
# Format instruction
message = {"role": "user", "content": example['question']}
prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)
# Format chosen answer
chosen = example['chosen'] + "<|im_end|>\n"
# Format rejected answer
rejected = example['rejected'] + "<|im_end|>\n"
return {
"prompt": system + prompt,
"chosen": chosen,
"rejected": rejected,
}
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Finetune a Mistral-7b model with DPO')
parser.add_argument('--repo-id-or-model-path', type=str, default="teknium/OpenHermes-2.5-Mistral-7B",
help='The huggingface repo id for the Mistral (e.g. `teknium/OpenHermes-2.5-Mistral-7B`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--dataset', type=str, default="Intel/orca_dpo_pairs")
parser.add_argument('--output-path', type=str, default="outputs")
parser.add_argument('--gradient-checkpointing', action='store_true', help='Whether to enable gradient checkpointing to save memory at the expense of slower backward pass.')
args = parser.parse_args()
model_path = args.repo_id_or_model_path
dataset_path = args.dataset
output_path = args.output_path
gradient_checkpointing = args.gradient_checkpointing
# Load dataset
dataset = load_dataset(dataset_path)['train']
# Save columns
original_columns = dataset.column_names
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
# Format dataset
dataset = dataset.map(
chatml_format,
remove_columns=original_columns
)
# LoRA configuration
peft_config = LoraConfig(
r=16,
lora_alpha=16,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj']
)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(model_path,
quantization_config=bnb_config, )
# below is also supported
# model = AutoModelForCausalLM.from_pretrained(model_path,
# load_in_low_bit="nf4",
# optimize_model=False,
# torch_dtype=torch.bfloat16,
# modules_to_not_convert=["lm_head"],)
model = model.to('xpu')
# Prepare a BigDL-LLM compatible Peft model
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=gradient_checkpointing)
model = get_peft_model(model, peft_config)
model.config.use_cache = False
model.print_trainable_parameters()
# Reference model, same as the main one
ref_model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_low_bit="nf4",
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],)
ref_model = ref_model.to('xpu')
# Training arguments
training_args = TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
gradient_checkpointing=gradient_checkpointing,
learning_rate=5e-5,
lr_scheduler_type="cosine",
max_steps=200,
save_strategy="no",
logging_steps=1,
output_dir=output_path,
# optim="paged_adamw_32bit", # "paged_adamw_32bit" is not supported yet
optim="adamw_hf",
warmup_steps=100,
bf16=True,
)
# Create DPO trainer
dpo_trainer = DPOTrainer(
model,
ref_model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
beta=0.1,
max_prompt_length=1024,
max_length=1536,
)
# Fine-tune model with DPO
dpo_trainer.train()
# Save artifacts
dpo_trainer.model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)

View file

@ -0,0 +1,44 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import torch
from transformers import AutoTokenizer
import argparse
current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..')
import sys
sys.path.append(common_util_path)
from common.utils import merge_adapter
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Merge the adapter into the original model for Mistral model')
parser.add_argument('--repo-id-or-model-path', type=str, default="teknium/OpenHermes-2.5-Mistral-7B",
help='The huggingface repo id the Mistral (e.g. `teknium/OpenHermes-2.5-Mistral-7B`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--adapter_path', type=str,)
parser.add_argument('--output_path', type=str,)
args = parser.parse_args()
base_model = model_path = args.repo_id_or_model_path
adapter_path = args.adapter_path
output_path = args.output_path
tokenizer = AutoTokenizer.from_pretrained(base_model)
merge_adapter(base_model, tokenizer, adapter_path, output_path)
print(f'Finish to merge the adapter into the original model and you could find the merged model in {output_path}.')

View file

@ -6,4 +6,5 @@ This folder contains examples of running different training mode with BigDL-LLM
- [QLoRA](QLoRA): examples of running QLoRA finetuning
- [QA-LoRA](QA-LoRA): examples of running QA-LoRA finetuning
- [ReLora](ReLora): examples of running ReLora finetuning
- [DPO](DPO): examples of running DPO finetuning
- [common](common): common templates and utility classes in finetuning examples