support ccl (MPI) distributed mode in alpaca_qlora_finetuning_cpu (#9507)
parent 0f0c6bb631
commit 48fbb1eb94

1 changed file with 13 additions and 11 deletions
alpaca_qlora_finetuning_cpu.py

@@ -91,8 +91,7 @@ def train(
     wandb_log_model: str = "",  # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
     prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
-    gradient_checkpointing: bool = False,
-    deepspeed: str = None,
+    gradient_checkpointing: bool = False
 ):
     if int(os.environ.get("LOCAL_RANK", 0)) == 0:
         print(
|
@@ -279,11 +278,6 @@ def train(
     else:
         train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
         val_data = None
-
-    trainer = transformers.Trainer(
-        model=model,
-        train_dataset=train_data,
-        eval_dataset=val_data,
-        args=transformers.TrainingArguments(
+    args = transformers.TrainingArguments(
             per_device_train_batch_size=micro_batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
|
@@ -309,9 +303,17 @@
             report_to="wandb" if use_wandb else None,
             run_name=wandb_run_name if use_wandb else None,
             gradient_checkpointing=gradient_checkpointing,
-            ddp_backend="ccl",
-            deepspeed=deepspeed,
-        ),
+            ddp_backend="ccl" if ddp else None,
+    )
+    if ddp:
+        from accelerate.state import PartialState
+        args.distributed_state = PartialState(cpu=True, backend=args.ddp_backend)
+
+    trainer = transformers.Trainer(
+        model=model,
+        train_dataset=train_data,
+        eval_dataset=val_data,
+        args=args,
         # Inputs are dynamically padded to the maximum length among all inputs
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
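The core of the change is that TrainingArguments is now built first and, when running multi-process, its distributed_state is replaced with a CPU PartialState bound to the oneCCL ("ccl") backend, so the Trainer runs CPU DDP under an MPI launcher. The sketch below is not part of the commit; it is a minimal standalone check of that same PartialState(cpu=True, backend="ccl") pattern. It assumes oneccl_bindings_for_pytorch (or the older torch_ccl package) is installed, since importing it registers the "ccl" backend with torch.distributed, and the file name and launch command shown are illustrative, not taken from the repo.

# minimal_ccl_check.py -- illustrative sketch, not part of this commit.
# Example launch (single node): mpirun -n 2 python minimal_ccl_check.py
# Multi-node runs additionally need MASTER_ADDR/MASTER_PORT in the environment.
import torch
import torch.distributed as dist

try:
    # Importing the Intel oneCCL bindings registers the "ccl" backend
    # with torch.distributed (assumption: the package is installed).
    import oneccl_bindings_for_pytorch  # noqa: F401
except ImportError:
    import torch_ccl  # noqa: F401  (older package name for the same bindings)

from accelerate.state import PartialState

# Same trick as the diff: force CPU and select the oneCCL backend explicitly.
# Under an MPI launcher, PartialState reads the rank/world-size environment
# and brings up the "ccl" process group on CPU.
state = PartialState(cpu=True, backend="ccl")
print(f"rank {state.process_index}/{state.num_processes} on {state.device}")

if state.num_processes > 1:
    # A tiny collective to confirm the CCL process group is functional.
    t = torch.tensor([float(state.process_index + 1)])
    dist.all_reduce(t)  # defaults to SUM across ranks
    if state.is_main_process:
        print(f"all_reduce sum = {t.item()}")

Patching args.distributed_state before constructing the Trainer presumably lets the existing Trainer code path run unchanged while reusing the CPU/oneCCL process group instead of setting up its own distributed state.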