fix dpo finetune (#12774)

parent 9697197f3e
commit 2e5f2e5dda

2 changed files with 9 additions and 11 deletions
````diff
@@ -17,11 +17,9 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install datasets
+pip install transformers==4.45.0 "trl<0.12.0" datasets
 pip install peft==0.10.0
-pip install 'trl<0.9'
-# Note, if you don't want to reinstall BNBs dependencies, append the `--no-deps` flag!
-pip install --no-deps --force-reinstall 'https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_multi-backend-refactor/bitsandbytes-0.44.1.dev0-py3-none-manylinux_2_24_x86_64.whl'
+pip install bitsandbytes==0.45.1
 ```
 
 ### 2. Configures OneAPI environment variables
````
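For reference, the README's setup block as it reads after this commit, assembled from the hunk above (the `conda create` line is the hunk's context): `trl` is now pinned below 0.12.0, and the development multi-backend `bitsandbytes` wheel is replaced by the released 0.45.1.

```sh
# Post-commit install sequence, assembled from the diff above
conda create -n llm python=3.11
conda activate llm
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install transformers==4.45.0 "trl<0.12.0" datasets
pip install peft==0.10.0
pip install bitsandbytes==0.45.1
```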
```diff
@@ -37,10 +37,10 @@ import torch
 from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
 from ipex_llm.transformers import AutoModelForCausalLM
 import transformers
-from transformers import AutoTokenizer, TrainingArguments, BitsAndBytesConfig
+from transformers import AutoTokenizer, BitsAndBytesConfig
 from datasets import load_dataset
 from peft import LoraConfig
-from trl import DPOTrainer
+from trl import DPOConfig, DPOTrainer
 import argparse
 
 
```
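The training-arguments class now comes from `trl` (`DPOConfig`, which extends `transformers.TrainingArguments`) rather than from `transformers` directly. Below is a minimal sketch, not part of the commit, of a guard that fails fast when the installed `trl` falls outside the pinned range; it assumes, as the README pin suggests, that trl 0.12 replaced `DPOTrainer`'s `tokenizer` argument with `processing_class`.

```python
# Hypothetical guard (not in this commit): verify the trl pin before the
# DPO imports so a mismatched environment fails with a clear message.
from packaging import version

import trl

if version.parse(trl.__version__) >= version.parse("0.12.0"):
    raise ImportError(
        "This example passes tokenizer= to DPOTrainer; trl>=0.12 expects "
        "processing_class= instead. Install 'trl<0.12.0' as pinned in the README."
    )
```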
```diff
@@ -143,12 +143,15 @@ if __name__ == "__main__":
     ref_model = ref_model.to('xpu')
 
     # Training arguments
-    training_args = TrainingArguments(
+    training_args = DPOConfig(
         per_device_train_batch_size=4,
         gradient_accumulation_steps=4,
         gradient_checkpointing=gradient_checkpointing,
         learning_rate=5e-5,
         lr_scheduler_type="cosine",
+        beta=0.1,
+        max_prompt_length=1024,
+        max_length=1536,
         max_steps=200,
         save_strategy="no",
         logging_steps=1,
```
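Together with the next hunk, this moves the DPO-specific settings onto the config: `beta`, `max_prompt_length`, and `max_length` become `DPOConfig` fields instead of `DPOTrainer` keyword arguments (in `trl`, `max_prompt_length` caps the tokenized prompt and `max_length` the full prompt-plus-completion sequence). For context on `beta=0.1`: in the DPO objective (Rafailov et al., 2023) it scales the implicit reward, i.e. how strongly the policy is held to the reference model:

```latex
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}})
  = -\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\left[
      \log\sigma\!\left(
        \beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)}
        \;-\;
        \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}
      \right)\right]
```

A smaller `beta` lets the policy drift further from `ref_model` for a given preference margin; 0.1 is the value used in the paper and trl's default.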
```diff
@@ -166,9 +169,6 @@ if __name__ == "__main__":
         args=training_args,
         train_dataset=dataset,
         tokenizer=tokenizer,
-        beta=0.1,
-        max_prompt_length=1024,
-        max_length=1536,
     )
 
     # Fine-tune model with DPO
```
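A sketch of how the trainer construction reads after the change: the opening of the call sits above this hunk, so the `model` and `ref_model` arguments below are assumptions based on the earlier hunks and trl's usual `DPOTrainer` signature, and the final `trainer.train()` is the standard entry point implied by the `# Fine-tune model with DPO` comment; only the `args`/`train_dataset`/`tokenizer` lines are taken from the diff.

```python
# Sketch of the post-commit call; arguments marked "assumed" are not
# visible in this hunk.
trainer = DPOTrainer(
    model,                  # assumed: the QLoRA-prepared policy model
    ref_model,              # assumed: the frozen reference moved to 'xpu' above
    args=training_args,     # DPOConfig now carries beta/max_prompt_length/max_length
    train_dataset=dataset,
    tokenizer=tokenizer,    # trl<0.12 keyword; later releases rename it processing_class
)
trainer.train()             # standard trl/transformers training entry point
```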