diff --git a/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning_cpu.py b/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning_cpu.py
index ee39cd7f..5ffae73e 100644
--- a/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning_cpu.py
+++ b/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning_cpu.py
@@ -91,8 +91,7 @@ def train(
     wandb_log_model: str = "",  # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
     prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
-    gradient_checkpointing: bool = False,
-    deepspeed: str = None,
+    gradient_checkpointing: bool = False
 ):
     if int(os.environ.get("LOCAL_RANK", 0)) == 0:
         print(
@@ -279,12 +278,7 @@ def train(
     else:
         train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
         val_data = None
-
-    trainer = transformers.Trainer(
-        model=model,
-        train_dataset=train_data,
-        eval_dataset=val_data,
-        args=transformers.TrainingArguments(
+    args = transformers.TrainingArguments(
             per_device_train_batch_size=micro_batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
             # warmup_ratio=0.03,
@@ -309,9 +303,17 @@ def train(
             report_to="wandb" if use_wandb else None,
             run_name=wandb_run_name if use_wandb else None,
             gradient_checkpointing=gradient_checkpointing,
-            ddp_backend="ccl",
-            deepspeed=deepspeed,
-        ),
+            ddp_backend="ccl" if ddp else None,
+    )
+    if ddp:
+        from accelerate.state import PartialState
+        args.distributed_state = PartialState(cpu=True, backend=args.ddp_backend)
+
+    trainer = transformers.Trainer(
+        model=model,
+        train_dataset=train_data,
+        eval_dataset=val_data,
+        args=args,
         # Inputs are dynamically padded to the maximum length among all inputs
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
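
For context, the change drops the `deepspeed` option and instead pre-initializes Accelerate's distributed state on CPU with the oneCCL backend, so that `transformers.Trainer` reuses that state rather than setting up a CUDA process group. Below is a minimal standalone sketch of the new branch, not the full script: `output_dir` and the hard-coded `ddp` flag are placeholders (in the script, `ddp` is derived from the launcher-provided `WORLD_SIZE`), and it assumes `accelerate` plus the oneCCL bindings for PyTorch are installed and that it runs under a distributed launcher that sets the usual rank/world-size environment variables.

```python
# Illustrative sketch only; mirrors the `if ddp:` branch added in this diff.
import transformers
from accelerate.state import PartialState

ddp = True  # placeholder: the script derives this from WORLD_SIZE

args = transformers.TrainingArguments(
    output_dir="./outputs",  # placeholder path
    per_device_train_batch_size=1,
    ddp_backend="ccl" if ddp else None,  # oneCCL backend for CPU collectives
)

if ddp:
    # Force a CPU process group on the chosen backend; Trainer then picks up
    # this pre-initialized state instead of creating a CUDA-backed one.
    args.distributed_state = PartialState(cpu=True, backend=args.ddp_backend)
```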