ipex-llm/python/llm/example/GPU/LLM-Finetuning/DPO/dpo_finetuning.py

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb
#
# Copyright [yyyy] [name of copyright owner]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import torch
import transformers
from transformers import AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig
from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
from ipex_llm.transformers import AutoModelForCausalLM
from trl import DPOTrainer
import argparse
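

# Convert one preference record (system / question / chosen / rejected) into the
# ChatML-formatted "prompt", "chosen" and "rejected" fields expected by DPOTrainer.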
def chatml_format(example):
    # Format system
    if len(example['system']) > 0:
        message = {"role": "system", "content": example['system']}
        system = tokenizer.apply_chat_template([message], tokenize=False)
    else:
        system = ""

    # Format instruction
    message = {"role": "user", "content": example['question']}
    prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)

    # Format chosen answer
    chosen = example['chosen'] + "<|im_end|>\n"

    # Format rejected answer
    rejected = example['rejected'] + "<|im_end|>\n"

    return {
        "prompt": system + prompt,
        "chosen": chosen,
        "rejected": rejected,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Finetune a Mistral-7b model with DPO')
    parser.add_argument('--repo-id-or-model-path', type=str, default="teknium/OpenHermes-2.5-Mistral-7B",
                        help='The Hugging Face repo id of the Mistral model to be downloaded'
                             ' (e.g. `teknium/OpenHermes-2.5-Mistral-7B`), or the path to a local checkpoint folder')
    parser.add_argument('--dataset', type=str, default="Intel/orca_dpo_pairs")
    parser.add_argument('--output-path', type=str, default="outputs")
    parser.add_argument('--gradient-checkpointing', action='store_true',
                        help='Whether to enable gradient checkpointing to save memory at the expense of a slower backward pass.')
    args = parser.parse_args()

    model_path = args.repo_id_or_model_path
    dataset_path = args.dataset
    output_path = args.output_path
    gradient_checkpointing = args.gradient_checkpointing

    # Load dataset
    dataset = load_dataset(dataset_path)['train']

    # Save columns
    original_columns = dataset.column_names

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
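    # The tokenizer may not define a dedicated pad token, so reuse EOS and pad on the
    # left; DPOTrainer needs a pad token to batch variable-length prompt/answer pairs.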
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    # Format dataset
    dataset = dataset.map(
        chatml_format,
        remove_columns=original_columns
    )

    # LoRA configuration
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj']
    )

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
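    # ipex_llm's AutoModelForCausalLM accepts the bitsandbytes-style 4-bit config above;
    # the commented-out call below is another supported way to load the model with
    # NF4 low-bit weights.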
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 quantization_config=bnb_config)
    # below is also supported
    # model = AutoModelForCausalLM.from_pretrained(model_path,
    #                                              load_in_low_bit="nf4",
    #                                              optimize_model=False,
    #                                              torch_dtype=torch.bfloat16,
    #                                              modules_to_not_convert=["lm_head"],)
    model = model.to('xpu')

    # Prepare an IPEX-LLM compatible PEFT model
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=gradient_checkpointing)
    model = get_peft_model(model, peft_config)
    model.config.use_cache = False
    model.print_trainable_parameters()
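    # DPO needs a frozen reference policy: the trainer compares the fine-tuned model's
    # log-probabilities against this copy, which is never updated, so it can stay in
    # low-bit form without LoRA adapters.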
    # Reference model, same as the main one
    ref_model = AutoModelForCausalLM.from_pretrained(model_path,
                                                     load_in_low_bit="nf4",
                                                     optimize_model=False,
                                                     torch_dtype=torch.bfloat16,
                                                     modules_to_not_convert=["lm_head"],)
    ref_model = ref_model.to('xpu')
    # Training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        gradient_checkpointing=gradient_checkpointing,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        max_steps=200,
        save_strategy="no",
        logging_steps=1,
        output_dir=output_path,
        # optim="paged_adamw_32bit", # "paged_adamw_32bit" is not supported yet
        optim="adamw_hf",
        warmup_steps=100,
        bf16=True,
    )
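    # beta controls how strongly the DPO loss keeps the policy close to the frozen
    # reference model (larger values penalize divergence more); 0.1 is the common default.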
    # Create DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        beta=0.1,
        max_prompt_length=1024,
        max_length=1536,
    )

    # Fine-tune model with DPO
    dpo_trainer.train()

    # Save artifacts
    dpo_trainer.model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)
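
# A typical invocation on an Intel GPU machine (assuming the required oneAPI
# environment is already sourced) might look like:
#
#   python ./dpo_finetuning.py --repo-id-or-model-path teknium/OpenHermes-2.5-Mistral-7B \
#       --dataset Intel/orca_dpo_pairs --output-path outputs --gradient-checkpointing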