LLM: support resume from checkpoint in Alpaca QLoRA (#9502)
This commit is contained in:
parent
139e98aa18
commit
1a2129221d
2 changed files with 12 additions and 22 deletions
|
|
@ -13,8 +13,8 @@ conda activate llm
|
|||
# below command will install intel_extension_for_pytorch==2.0.110+xpu as default
|
||||
# you can install specific ipex/torch version for your need
|
||||
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
|
||||
pip install transformers==4.34.0
|
||||
pip install fire datasets peft==0.5.0
|
||||
pip install datasets transformers==4.34.0
|
||||
pip install fire peft==0.5.0
|
||||
pip install oneccl_bind_pt==2.0.100 -f https://developer.intel.com/ipex-whl-stable-xpu # necessary to run distributed finetuning
|
||||
pip install accelerate==0.23.0
|
||||
```
|
||||
|
|
@ -76,6 +76,15 @@ bash finetune_llama2_7b_pvc_1550_1_card.sh
|
|||
bash finetune_llama2_7b_pvc_1550_4_card.sh
|
||||
```
|
||||
|
||||
**Important: If you fail to complete the whole finetuning process, it is suggested to resume training from a previously saved checkpoint by specifying `resume_from_checkpoint` to the local checkpoint folder as following:**
|
||||
```bash
|
||||
python ./alpaca_qlora_finetuning.py \
|
||||
--base_model "meta-llama/Llama-2-7b-hf" \
|
||||
--data_path "yahma/alpaca-cleaned" \
|
||||
--output_dir "./bigdl-qlora-alpaca" \
|
||||
--resume_from_checkpoint "./bigdl-qlora-alpaca/checkpoint-1100"
|
||||
```
|
||||
|
||||
### 4. Sample Output
|
||||
```log
|
||||
{'loss': 1.9231, 'learning_rate': 2.9999945367033285e-05, 'epoch': 0.0}
|
||||
|
|
|
|||
|
|
@ -261,26 +261,6 @@ def train(
|
|||
else:
|
||||
data = load_dataset(data_path)
|
||||
|
||||
if resume_from_checkpoint:
|
||||
# Check the available weights and load them
|
||||
checkpoint_name = os.path.join(
|
||||
resume_from_checkpoint, "pytorch_model.bin"
|
||||
) # Full checkpoint
|
||||
if not os.path.exists(checkpoint_name):
|
||||
checkpoint_name = os.path.join(
|
||||
resume_from_checkpoint, "adapter_model.bin"
|
||||
) # only LoRA model - LoRA config above has to fit
|
||||
resume_from_checkpoint = (
|
||||
False # So the trainer won't try loading its state
|
||||
)
|
||||
# The two files above have a different name depending on how they were saved, but are actually the same.
|
||||
if os.path.exists(checkpoint_name):
|
||||
print(f"Restarting from {checkpoint_name}")
|
||||
adapters_weights = torch.load(checkpoint_name)
|
||||
set_peft_model_state_dict(model, adapters_weights)
|
||||
else:
|
||||
print(f"Checkpoint {checkpoint_name} not found")
|
||||
|
||||
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
|
||||
|
||||
if val_set_size > 0:
|
||||
|
|
@ -336,6 +316,7 @@ def train(
|
|||
gradient_checkpointing=gradient_checkpointing,
|
||||
ddp_backend="ccl",
|
||||
deepspeed=deepspeed,
|
||||
save_safetensors=False,
|
||||
),
|
||||
data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
||||
|
|
|
|||
Loading…
Reference in a new issue