fix dpo finetune (#12774)

Yishuo Wang 2025-02-06 16:35:21 +08:00 committed by GitHub
parent 9697197f3e
commit 2e5f2e5dda
2 changed files with 9 additions and 11 deletions

View file

@@ -17,11 +17,9 @@ conda create -n llm python=3.11
conda activate llm
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install datasets
+pip install transformers==4.45.0 "trl<0.12.0" datasets
pip install peft==0.10.0
-pip install 'trl<0.9'
-# Note, if you don't want to reinstall BNBs dependencies, append the `--no-deps` flag!
-pip install --no-deps --force-reinstall 'https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_multi-backend-refactor/bitsandbytes-0.44.1.dev0-py3-none-manylinux_2_24_x86_64.whl'
+pip install bitsandbytes==0.45.1
```
### 2. Configures OneAPI environment variables
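A quick way to confirm the environment actually matches these new pins is a small version check. The snippet below is an optional sanity check and not part of this commit; the package names and version constraints are taken from the updated install commands above.

```python
# Optional sanity check (not part of this commit): print the installed versions
# of the packages pinned by the updated install instructions.
from importlib.metadata import version

pins = {
    "transformers": "== 4.45.0",
    "trl": "< 0.12.0",
    "peft": "== 0.10.0",
    "bitsandbytes": "== 0.45.1",
}
for pkg, constraint in pins.items():
    print(f"{pkg}: installed {version(pkg)} (pinned {constraint})")
```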

View file

@@ -37,10 +37,10 @@ import torch
from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
from ipex_llm.transformers import AutoModelForCausalLM
import transformers
-from transformers import AutoTokenizer, TrainingArguments, BitsAndBytesConfig
+from transformers import AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig
-from trl import DPOTrainer
+from trl import DPOConfig, DPOTrainer
import argparse
@@ -83,7 +83,7 @@ if __name__ == "__main__":
dataset_path = args.dataset
output_path = args.output_path
gradient_checkpointing = args.gradient_checkpointing
# Load dataset
dataset = load_dataset(dataset_path)['train']
@@ -143,12 +143,15 @@ if __name__ == "__main__":
ref_model = ref_model.to('xpu')
# Training arguments
-training_args = TrainingArguments(
+training_args = DPOConfig(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
gradient_checkpointing=gradient_checkpointing,
learning_rate=5e-5,
lr_scheduler_type="cosine",
+beta=0.1,
+max_prompt_length=1024,
+max_length=1536,
max_steps=200,
save_strategy="no",
logging_steps=1,
@@ -166,9 +169,6 @@ if __name__ == "__main__":
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
-beta=0.1,
-max_prompt_length=1024,
-max_length=1536,
)
# Fine-tune model with DPO
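Taken together, the script hunks swap TrainingArguments for DPOConfig and move the DPO-specific options (beta, max_prompt_length, max_length) off the DPOTrainer call and onto the config, which is where trl versions below 0.12 expect them. Below is a minimal, self-contained sketch of that API shape, assuming transformers==4.45.0 and a recent trl below 0.12 (one that provides DPOConfig) as pinned above. The small model ("gpt2"), the toy in-memory preference dataset, and the output_dir name are illustrative stand-ins rather than the quantized XPU model and the dataset loaded from args.dataset in the actual example; only the DPOConfig/DPOTrainer wiring mirrors the diff.

```python
# Minimal sketch of the trl<0.12 DPO API shape this commit switches to.
# "gpt2", the toy dataset, and the output_dir name are illustrative stand-ins;
# only the DPOConfig/DPOTrainer wiring mirrors the diff above.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

# DPO expects a preference dataset with "prompt", "chosen" and "rejected" columns.
dataset = Dataset.from_dict({
    "prompt":   ["What is 2 + 2?"],
    "chosen":   ["2 + 2 = 4."],
    "rejected": ["2 + 2 = 5."],
})

# The DPO-specific knobs (beta, max_prompt_length, max_length) now live on the
# config object instead of being passed to DPOTrainer directly.
training_args = DPOConfig(
    output_dir="dpo-sketch",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    beta=0.1,
    max_prompt_length=1024,
    max_length=1536,
    max_steps=1,
    save_strategy="no",
    logging_steps=1,
    report_to="none",
)

trainer = DPOTrainer(
    model,
    None,                    # ref_model=None: trl builds a frozen copy itself
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,     # trl<0.12 still accepts `tokenizer=` here
)
trainer.train()
```

In the real example the same wiring applies, just with the ipex-llm quantized model and reference model moved to 'xpu' and the dataset loaded from args.dataset, as shown in the hunks above.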