support ccl (MPI) distributed mode in alpaca_qlora_finetuning_cpu (#9507)
parent 0f0c6bb631
commit 48fbb1eb94

1 changed file with 13 additions and 11 deletions

@@ -91,8 +91,7 @@ def train(
     wandb_log_model: str = "", # options: false | true
     resume_from_checkpoint: str = None, # either training checkpoint or final adapter
     prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
-    gradient_checkpointing: bool = False,
-    deepspeed: str = None,
+    gradient_checkpointing: bool = False
 ):
     if int(os.environ.get("LOCAL_RANK", 0)) == 0:
         print(
@@ -279,12 +278,7 @@ def train(
     else:
         train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
         val_data = None
-    trainer = transformers.Trainer(
-        model=model,
-        train_dataset=train_data,
-        eval_dataset=val_data,
-        args=transformers.TrainingArguments(
+    args = transformers.TrainingArguments(
             per_device_train_batch_size=micro_batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
             # warmup_ratio=0.03,
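
The hunk above only hoists transformers.TrainingArguments out of the Trainer(...) call and binds it to a standalone args variable; the training arguments themselves are unchanged. The point of the restructuring is that a TrainingArguments object can be inspected and patched after construction but before the Trainer ever sees it, which the next hunk uses to swap in a CPU/CCL distributed state. A minimal sketch of that pattern, with placeholder values that are not taken from the script:

    import transformers

    # Build the arguments object up front (all values below are placeholders).
    args = transformers.TrainingArguments(
        output_dir="./outputs",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
    )

    # Because args exists before the Trainer does, fields such as the
    # distributed backend can still be adjusted here.
    args.ddp_backend = "ccl"
    print(args.ddp_backend)  # -> ccl
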
@@ -309,9 +303,17 @@ def train(
             report_to="wandb" if use_wandb else None,
             run_name=wandb_run_name if use_wandb else None,
             gradient_checkpointing=gradient_checkpointing,
-            ddp_backend="ccl",
-            deepspeed=deepspeed,
-        ),
+            ddp_backend="ccl" if ddp else None,
+    )
+    if ddp:
+        from accelerate.state import PartialState
+        args.distributed_state = PartialState(cpu=True, backend=args.ddp_backend)
+
+    trainer = transformers.Trainer(
+        model=model,
+        train_dataset=train_data,
+        eval_dataset=val_data,
+        args=args,
         # Inputs are dynamically padded to the maximum length among all inputs
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
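
Taken together, the last hunk routes multi-process CPU training through PyTorch's ccl (oneCCL over MPI) backend: the backend is only requested when the run is actually distributed, and the trainer's distributed state is rebuilt on CPU with that backend before the Trainer is constructed. The ddp flag itself is not visible in this diff; the sketch below assumes the usual alpaca-lora convention of deriving it from the WORLD_SIZE environment variable, and further assumes oneccl_bindings_for_pytorch is installed and the processes are started by an MPI or torch-distributed launcher that sets RANK/WORLD_SIZE/MASTER_ADDR. The output directory and batch settings are placeholders.

    import os
    import transformers
    from accelerate.state import PartialState

    # Assumption (not shown in this diff): treat any multi-process launch as DDP.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1

    args = transformers.TrainingArguments(
        output_dir="./outputs",                 # placeholder path
        per_device_train_batch_size=2,          # placeholder value
        gradient_accumulation_steps=8,          # placeholder value
        ddp_backend="ccl" if ddp else None,     # request oneCCL only when distributed
    )

    if ddp:
        # Rebuild the shared distributed state on CPU with the oneCCL backend,
        # mirroring what the patched script does before it creates the Trainer.
        args.distributed_state = PartialState(cpu=True, backend=args.ddp_backend)

    # A real run would then build the Trainer from the pre-constructed arguments:
    # trainer = transformers.Trainer(model=model, train_dataset=train_data,
    #                                eval_dataset=val_data, args=args, ...)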