fix dpo finetune (#12774)
parent 9697197f3e
commit 2e5f2e5dda
2 changed files with 9 additions and 11 deletions
@@ -17,11 +17,9 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install datasets
+pip install transformers==4.45.0 "trl<0.12.0" datasets
 pip install peft==0.10.0
-pip install 'trl<0.9'
-# Note, if you don't want to reinstall BNBs dependencies, append the `--no-deps` flag!
-pip install --no-deps --force-reinstall 'https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_multi-backend-refactor/bitsandbytes-0.44.1.dev0-py3-none-manylinux_2_24_x86_64.whl'
+pip install bitsandbytes==0.45.1
 ```
 
 ### 2. Configures OneAPI environment variables
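The new pins above are what the script changes below rely on: the trl version that "trl<0.12.0" resolves to (0.11.x at the time of this change, an assumption) already ships `DPOConfig` and still accepts `tokenizer=` in `DPOTrainer`. A minimal sanity-check sketch, assuming the packages were installed into the active conda env exactly as shown, with the expected values taken from the pinned versions:

```python
# Quick environment check: print the versions that actually resolved.
# The DPOConfig import succeeding is itself the check that the trl pin is new enough
# for the updated fine-tuning script below.
import bitsandbytes
import peft
import transformers
import trl
from trl import DPOConfig

print("transformers :", transformers.__version__)   # expected 4.45.0
print("trl          :", trl.__version__)            # expected below 0.12.0
print("peft         :", peft.__version__)           # expected 0.10.0
print("bitsandbytes :", bitsandbytes.__version__)   # expected 0.45.1
```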
@@ -37,10 +37,10 @@ import torch
 from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
 from ipex_llm.transformers import AutoModelForCausalLM
 import transformers
-from transformers import AutoTokenizer, TrainingArguments, BitsAndBytesConfig
+from transformers import AutoTokenizer, BitsAndBytesConfig
 from datasets import load_dataset
 from peft import LoraConfig
-from trl import DPOTrainer
+from trl import DPOConfig, DPOTrainer
 import argparse
 
 
@@ -83,7 +83,7 @@ if __name__ == "__main__":
     dataset_path = args.dataset
     output_path = args.output_path
     gradient_checkpointing = args.gradient_checkpointing
 
     # Load dataset
     dataset = load_dataset(dataset_path)['train']
 
@@ -143,12 +143,15 @@ if __name__ == "__main__":
     ref_model = ref_model.to('xpu')
 
     # Training arguments
-    training_args = TrainingArguments(
+    training_args = DPOConfig(
         per_device_train_batch_size=4,
         gradient_accumulation_steps=4,
         gradient_checkpointing=gradient_checkpointing,
         learning_rate=5e-5,
         lr_scheduler_type="cosine",
+        beta=0.1,
+        max_prompt_length=1024,
+        max_length=1536,
         max_steps=200,
         save_strategy="no",
         logging_steps=1,
@@ -166,9 +169,6 @@ if __name__ == "__main__":
         args=training_args,
         train_dataset=dataset,
         tokenizer=tokenizer,
-        beta=0.1,
-        max_prompt_length=1024,
-        max_length=1536,
     )
 
     # Fine-tune model with DPO
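Taken together, the script-side change moves the DPO-specific settings (beta, max_prompt_length, max_length) onto DPOConfig and leaves DPOTrainer with only the models, the config, the dataset, and the tokenizer. Below is a minimal, self-contained sketch of that wiring under trl < 0.12; the tiny model and one-row in-memory dataset are placeholders chosen so the sketch runs on CPU and are not part of this example. Only the DPOConfig fields and the trimmed DPOTrainer call mirror the diff above.

```python
# Minimal sketch of the post-fix wiring (trl < 0.12): hyperparameters live on DPOConfig
# and the DPOTrainer call stays slim. Placeholder model and dataset are used only to keep
# the sketch self-contained; swap in the real model, dataset, and XPU setup.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "sshleifer/tiny-gpt2"  # placeholder model, not the one used by this example
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 style tokenizers define no pad token
model = AutoModelForCausalLM.from_pretrained(model_id)

# One-row preference dataset in the explicit prompt/chosen/rejected string format.
dataset = Dataset.from_dict({
    "prompt":   ["What is the capital of France?"],
    "chosen":   ["The capital of France is Paris."],
    "rejected": ["France has no capital."],
})

training_args = DPOConfig(
    output_dir="dpo-out",             # required by the underlying TrainingArguments
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=False,     # the example script reads this from a CLI flag
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    beta=0.1,                         # moved here from the DPOTrainer(...) call
    max_prompt_length=1024,           # moved here from the DPOTrainer(...) call
    max_length=1536,                  # moved here from the DPOTrainer(...) call
    max_steps=200,
    save_strategy="no",
    logging_steps=1,
)

trainer = DPOTrainer(
    model,
    ref_model=None,                   # None lets DPOTrainer build the frozen reference copy
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,              # still a valid keyword in trl < 0.12
)
trainer.train()
```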