LLM: add alpaca qlora finetuning example (#9276)

This commit is contained in:
binbin Deng 2023-11-08 16:25:17 +08:00 committed by GitHub
parent 97316bbb66
commit 54d95e4907
7 changed files with 495 additions and 0 deletions

View file

@ -0,0 +1,50 @@
# Alpaca QLoRA Finetuning (experimental support)
This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM QLoRA on [Intel GPUs](../../README.md).
### 0. Requirements
To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../../README.md#requirements) for more information.
### 1. Install
```bash
conda create -n llm python=3.9
conda activate llm
# below command will install intel_extension_for_pytorch==2.0.110+xpu as default
# you can install specific ipex/torch version for your need
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
pip install transformers==4.34.0
pip install fire datasets peft==0.5.0
pip install accelerate==0.23.0
```
### 2. Configures OneAPI environment variables
```bash
source /opt/intel/oneapi/setvars.sh
```
### 3. Finetuning LLaMA-2-7B on a single Arc:
Example usage:
```
python ./alpaca_qlora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca"
```
**Note**: You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file.
#### Sample Output
```log
{'loss': 1.9231, 'learning_rate': 2.9999945367033285e-05, 'epoch': 0.0}
{'loss': 1.8622, 'learning_rate': 2.9999781468531096e-05, 'epoch': 0.01}
{'loss': 1.9043, 'learning_rate': 2.9999508305687345e-05, 'epoch': 0.01}
{'loss': 1.8967, 'learning_rate': 2.999912588049185e-05, 'epoch': 0.01}
{'loss': 1.9658, 'learning_rate': 2.9998634195730358e-05, 'epoch': 0.01}
{'loss': 1.8386, 'learning_rate': 2.9998033254984483e-05, 'epoch': 0.02}
{'loss': 1.809, 'learning_rate': 2.999732306263172e-05, 'epoch': 0.02}
{'loss': 1.8552, 'learning_rate': 2.9996503623845395e-05, 'epoch': 0.02}
1%|█ | 8/1164 [xx:xx<xx:xx:xx, xx s/it]
```

View file

@ -0,0 +1,336 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/tloen/alpaca-lora/blob/main/finetune.py
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List
import fire
import torch
import transformers
from datasets import load_dataset
import accelerate
from transformers import LlamaTokenizer
from peft import (
LoraConfig,
get_peft_model_state_dict,
set_peft_model_state_dict,
)
from utils.prompter import Prompter
import intel_extension_for_pytorch as ipex
from bigdl.llm.transformers import AutoModelForCausalLM
# import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
def train(
# model/data params
base_model: str = "meta-llama/Llama-2-7b-hf", # the only required argument, default to be "meta-llama/Llama-2-7b-hf"
saved_low_bit_model: str = None, # optional, the path to the saved model with bigdl-llm low-bit optimization
data_path: str = "yahma/alpaca-cleaned",
output_dir: str = "./bigdl-qlora-alpaca",
# training hyperparams
batch_size: int = 128,
micro_batch_size: int = 2, # default to be 2, limited by GPU memory
num_epochs: int = 3,
learning_rate: float = 3e-5, # default to be 3e-5 to avoid divergence
cutoff_len: int = 256,
val_set_size: int = 2000,
# lora hyperparams
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules: List[str] = [
"q_proj",
"v_proj",
"k_proj",
"o_proj",
"up_proj",
"down_proj",
"gate_proj"
], # according to the QLoRA paper (https://arxiv.org/pdf/2305.14314.pdf), it's suggested to fine tune all linear layers
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
add_eos_token: bool = False,
group_by_length: bool = False, # faster, but produces an odd training loss curve
# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
gradient_checkpointing: bool = False,
deepspeed: str = None,
):
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
print(
f"Training Alpaca-LoRA model with params:\n"
f"base_model: {base_model}\n"
f"data_path: {data_path}\n"
f"output_dir: {output_dir}\n"
f"batch_size: {batch_size}\n"
f"micro_batch_size: {micro_batch_size}\n"
f"num_epochs: {num_epochs}\n"
f"learning_rate: {learning_rate}\n"
f"cutoff_len: {cutoff_len}\n"
f"val_set_size: {val_set_size}\n"
f"lora_r: {lora_r}\n"
f"lora_alpha: {lora_alpha}\n"
f"lora_dropout: {lora_dropout}\n"
f"lora_target_modules: {lora_target_modules}\n"
f"train_on_inputs: {train_on_inputs}\n"
f"add_eos_token: {add_eos_token}\n"
f"group_by_length: {group_by_length}\n"
f"wandb_project: {wandb_project}\n"
f"wandb_run_name: {wandb_run_name}\n"
f"wandb_watch: {wandb_watch}\n"
f"wandb_log_model: {wandb_log_model}\n"
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
f"prompt template: {prompt_template_name}\n"
)
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
gradient_accumulation_steps = batch_size // micro_batch_size
prompter = Prompter(prompt_template_name)
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
gradient_accumulation_steps = gradient_accumulation_steps // world_size
# Check if parameter passed or if set within environ
use_wandb = len(wandb_project) > 0 or (
"WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
)
# Only overwrite environ if wandb param passed
if len(wandb_project) > 0:
os.environ["WANDB_PROJECT"] = wandb_project
if len(wandb_watch) > 0:
os.environ["WANDB_WATCH"] = wandb_watch
if len(wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = wandb_log_model
if saved_low_bit_model is not None:
# Load the low bit optimized model if provide the saved path
model = AutoModelForCausalLM.load_low_bit(
saved_low_bit_model,
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
)
else:
# Load the base model from a directory or the HF Hub to 4-bit NormalFloat format
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_low_bit="nf4", # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
optimize_model=False,
torch_dtype=torch.bfloat16,
# device_map=device_map,
modules_to_not_convert=["lm_head"],
)
print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
tokenizer = LlamaTokenizer.from_pretrained(base_model)
print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")
tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference
print(model)
def tokenize(prompt, add_eos_token=True):
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
result = tokenizer(
prompt,
truncation=True,
max_length=cutoff_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < cutoff_len
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
def generate_and_tokenize_prompt(data_point):
full_prompt = prompter.generate_prompt(
data_point["instruction"],
data_point["input"],
data_point["output"],
)
tokenized_full_prompt = tokenize(full_prompt)
if not train_on_inputs:
user_prompt = prompter.generate_prompt(
data_point["instruction"], data_point["input"]
)
tokenized_user_prompt = tokenize(
user_prompt, add_eos_token=add_eos_token
)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
if add_eos_token:
user_prompt_len -= 1
tokenized_full_prompt["labels"] = [
-100
] * user_prompt_len + tokenized_full_prompt["labels"][
user_prompt_len:
] # could be sped up, probably
return tokenized_full_prompt
# Prepare a BigDL-LLM compatible Peft model
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=gradient_checkpointing)
config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
data = load_dataset("json", data_files=data_path)
else:
data = load_dataset(data_path)
if resume_from_checkpoint:
# Check the available weights and load them
checkpoint_name = os.path.join(
resume_from_checkpoint, "pytorch_model.bin"
) # Full checkpoint
if not os.path.exists(checkpoint_name):
checkpoint_name = os.path.join(
resume_from_checkpoint, "adapter_model.bin"
) # only LoRA model - LoRA config above has to fit
resume_from_checkpoint = (
False # So the trainer won't try loading its state
)
# The two files above have a different name depending on how they were saved, but are actually the same.
if os.path.exists(checkpoint_name):
print(f"Restarting from {checkpoint_name}")
adapters_weights = torch.load(checkpoint_name)
set_peft_model_state_dict(model, adapters_weights)
else:
print(f"Checkpoint {checkpoint_name} not found")
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
if val_set_size > 0:
train_val = data["train"].train_test_split(
test_size=val_set_size, shuffle=True, seed=42
)
train_data = (
train_val["train"].shuffle().map(generate_and_tokenize_prompt)
)
val_data = (
train_val["test"].shuffle().map(generate_and_tokenize_prompt)
)
else:
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = None
# Unused
# if not ddp and torch.cuda.device_count() > 1:
# # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
# model.is_parallelizable = True
# model.model_parallel = True
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
# warmup_ratio=0.03,
# warmup_steps=100,
max_grad_norm=0.3,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
lr_scheduler_type="cosine",
bf16=True, # ensure training more stable
logging_steps=1,
optim="adamw_torch",
evaluation_strategy="steps" if val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=100 if val_set_size > 0 else None,
save_steps=100,
output_dir=output_dir,
save_total_limit=100,
load_best_model_at_end=True if val_set_size > 0 else False,
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,
report_to="wandb" if use_wandb else None,
run_name=wandb_run_name if use_wandb else None,
gradient_checkpointing=gradient_checkpointing,
ddp_backend="ccl",
deepspeed=deepspeed,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
model.config.use_cache = False
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
model.save_pretrained(output_dir)
print(
"\n If there's a warning about missing keys above, please disregard :)"
)
if __name__ == "__main__":
fire.Fire(train)

View file

@ -0,0 +1,7 @@
{
"//": "This file is copied from https://github.com/tloen/alpaca-lora/blob/main/templates/alpaca.json",
"description": "Template used by Alpaca-LoRA.",
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
"response_split": "### Response:"
}

View file

@ -0,0 +1,7 @@
{
"//": "This file is copied from https://github.com/tloen/alpaca-lora/blob/main/templates/alpaca_legacy.json",
"description": "Legacy template, used by Original Alpaca repository.",
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:",
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
"response_split": "### Response:"
}

View file

@ -0,0 +1,7 @@
{
"//": "This file is copied from https://github.com/tloen/alpaca-lora/blob/main/templates/alpaca_short.json",
"description": "A shorter template to experiment with.",
"prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
"prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
"response_split": "### Response:"
}

View file

@ -0,0 +1,7 @@
{
"//": "This file is copied from https://github.com/tloen/alpaca-lora/blob/main/templates/vigogne.json",
"description": "French template, used by Vigogne for finetuning.",
"prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n",
"prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n",
"response_split": "### Réponse:"
}

View file

@ -0,0 +1,81 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/tloen/alpaca-lora/blob/main/utils/prompter.py
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os.path as osp
from typing import Union
from bigdl.llm.utils.common import invalidInputError
class Prompter(object):
__slots__ = ("template", "_verbose")
def __init__(self, template_name: str = "", verbose: bool = False):
self._verbose = verbose
if not template_name:
# Enforce the default here, so the constructor can be called with '' and will not break.
template_name = "alpaca"
file_name = osp.join("templates", f"{template_name}.json")
if not osp.exists(file_name):
invalidInputError(False, f"Can't read {file_name}")
with open(file_name) as fp:
self.template = json.load(fp)
if self._verbose:
print(
f"Using prompt template {template_name}: {self.template['description']}"
)
def generate_prompt(
self,
instruction: str,
input: Union[None, str]=None,
label: Union[None, str]=None,
) -> str:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = self.template["prompt_input"].format(
instruction=instruction, input=input
)
else:
res = self.template["prompt_no_input"].format(
instruction=instruction
)
if label:
res = f"{res}{label}"
if self._verbose:
print(res)
return res
def get_response(self, output: str) -> str:
return output.split(self.template["response_split"])[1].strip()