LLM: support resume from checkpoint in Alpaca QLoRA (#9502)

binbin Deng 2023-11-22 13:49:14 +08:00 committed by GitHub
parent 139e98aa18
commit 1a2129221d
2 changed files with 12 additions and 22 deletions

README.md

@@ -13,8 +13,8 @@ conda activate llm
# below command will install intel_extension_for_pytorch==2.0.110+xpu by default
# you can install a specific ipex/torch version for your needs
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
pip install transformers==4.34.0
pip install fire datasets peft==0.5.0
pip install datasets transformers==4.34.0
pip install fire peft==0.5.0
pip install oneccl_bind_pt==2.0.100 -f https://developer.intel.com/ipex-whl-stable-xpu # necessary to run distributed finetuning
pip install accelerate==0.23.0
```
@@ -76,6 +76,15 @@ bash finetune_llama2_7b_pvc_1550_1_card.sh
bash finetune_llama2_7b_pvc_1550_4_card.sh
```
**Important: If the finetuning process is interrupted before it completes, you can resume training from a previously saved checkpoint by setting `resume_from_checkpoint` to the local checkpoint folder, as follows:**
```bash
python ./alpaca_qlora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--resume_from_checkpoint "./bigdl-qlora-alpaca/checkpoint-1100"
```
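Under the hood, checkpoint resumption is delegated to the Hugging Face `Trainer`. The snippet below is a minimal sketch of that mechanism, assuming the script forwards its `resume_from_checkpoint` argument to `trainer.train()`; the helper function and its parameters are illustrative and not the repository's exact code.
```python
# Minimal sketch of checkpoint resumption with the Hugging Face Trainer
# (illustrative only; the function name and signature are assumptions).
import transformers

def run_training(model, train_data, training_args, resume_from_checkpoint=None):
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        args=training_args,
    )
    # Given a checkpoint folder such as "./bigdl-qlora-alpaca/checkpoint-1100",
    # Trainer.train() restores the saved model weights, optimizer, LR scheduler
    # and trainer state, then continues from the recorded global step.
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
```
Checkpoints are written under `output_dir` as `checkpoint-<step>` folders, so any of them can be passed to `--resume_from_checkpoint`.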
### 4. Sample Output
```log
{'loss': 1.9231, 'learning_rate': 2.9999945367033285e-05, 'epoch': 0.0}

alpaca_qlora_finetuning.py

@@ -261,26 +261,6 @@ def train(
    else:
        data = load_dataset(data_path)
    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")
    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
    if val_set_size > 0:
@@ -336,6 +316,7 @@ def train(
            gradient_checkpointing=gradient_checkpointing,
            ddp_backend="ccl",
            deepspeed=deepspeed,
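            # Write checkpoints in the PyTorch .bin format (e.g. pytorch_model.bin /
            # adapter_model.bin) instead of safetensors; presumably chosen here for
            # compatibility with the checkpoint-loading path used when resuming.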
            save_safetensors=False,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True