LLM: Add CPU alpaca qlora example (#9469)

* init

* update xpu to cpu

* update

* update readme

* update example

* update

* add refer

* add guide to train different datasets

* update readme

* update
Wang, Jian4 2023-11-21 09:19:58 +08:00 committed by GitHub
parent 96fd26759c
commit c5cb3ab82e
10 changed files with 577 additions and 3 deletions

View file

@ -19,8 +19,8 @@ pip install datasets
```
### 2. Finetune model
If the machine memory is not enough, you can try to set `use_gradient_checkpointing=True` in [here](https://github.com/intel-analytics/BigDL/blob/1747ffe60019567482b6976a24b05079274e7fc8/python/llm/example/CPU/QLoRA-FineTuning/qlora_finetuning_cpu.py#L53C6-L53C6).
If the machine memory is not enough, you can try setting `use_gradient_checkpointing=True` [here](https://github.com/intel-analytics/BigDL/blob/1747ffe60019567482b6976a24b05079274e7fc8/python/llm/example/CPU/QLoRA-FineTuning/qlora_finetuning_cpu.py#L53C6-L53C6). While gradient checkpointing may improve memory efficiency, it slows training by approximately 20%.
We recommend using a `micro_batch_size` of 8 for better performance with 48 cores in this example. You can refer to [this guide](https://huggingface.co/docs/transformers/perf_train_gpu_one) for more details.
Also remember to use `bigdl-llm-init` before you start finetuning, which can accelerate the job.
```

View file

@ -0,0 +1,85 @@
# Alpaca QLoRA Finetuning (experimental support)
This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM QLoRA on [Intel CPUs](../../README.md).
### 1. Install
```bash
conda create -n llm python=3.9
conda activate llm
pip install --pre --upgrade bigdl-llm[all]
pip install transformers==4.34.0
pip install fire datasets peft==0.5.0
pip install accelerate==0.23.0
```
### 2. Configure environment variables
```bash
source bigdl-llm-init -t
```
### 3. Finetune LLaMA-2-7B on a node
Example usage:
```bash
python ./alpaca_qlora_finetuning_cpu.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca"
```
**Note**: You can also set `--base_model` to the local path of a Hugging Face model checkpoint folder and `--data_path` to the local path of a dataset JSON file.
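For example, assuming the checkpoint and dataset have already been downloaded to local folders (the paths below are illustrative), the command might look like:
```bash
# Illustrative local paths; replace them with your own locations.
python ./alpaca_qlora_finetuning_cpu.py \
    --base_model "./models/Llama-2-7b-hf" \
    --data_path "./data/alpaca_data_cleaned.json" \
    --output_dir "./bigdl-qlora-alpaca"
```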
#### Sample Output
```log
{'loss': 1.9231, 'learning_rate': 2.9999945367033285e-05, 'epoch': 0.0}
{'loss': 1.8622, 'learning_rate': 2.9999781468531096e-05, 'epoch': 0.01}
{'loss': 1.9043, 'learning_rate': 2.9999508305687345e-05, 'epoch': 0.01}
{'loss': 1.8967, 'learning_rate': 2.999912588049185e-05, 'epoch': 0.01}
{'loss': 1.9658, 'learning_rate': 2.9998634195730358e-05, 'epoch': 0.01}
{'loss': 1.8386, 'learning_rate': 2.9998033254984483e-05, 'epoch': 0.02}
{'loss': 1.809, 'learning_rate': 2.999732306263172e-05, 'epoch': 0.02}
{'loss': 1.8552, 'learning_rate': 2.9996503623845395e-05, 'epoch': 0.02}
1%|█ | 8/1164 [xx:xx<xx:xx:xx, xx s/it]
```
### Guide to using different prompts or datasets
The prompter currently expects datasets with `instruction`, `input` (optional) and `output` fields. To use a different dataset, add a template file `xxx.json` under `templates`, then update the `generate_prompt` method in `utils/prompter.py` and the `generate_and_tokenize_prompt` method in the training script to match the dataset.
For example, to train Llama-2-7B on [english_quotes](https://huggingface.co/datasets/Abirate/english_quotes), just like [this example](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/CPU/QLoRA-FineTuning/qlora_finetuning_cpu.py), follow the steps below (a quick sanity check of the new template is shown after step 4):
1. Add a template file `english_quotes.json` under `templates`:
```json
{
    "prompt": "{quote} ->: {tags}"
}
```
2. Update `utils/prompter.py` and add a new `generate_quote_prompt` method to the `Prompter` class:
```python
def generate_quote_prompt(self, quote: str, tags: Union[None, list]=None,) -> str:
    tags = str(tags)
    res = self.template["prompt"].format(
        quote=quote, tags=tags
    )
    if self._verbose:
        print(res)
    return res
```
3. Update the `generate_and_tokenize_prompt` method in the training script to use the new prompter method:
```python
def generate_and_tokenize_prompt(data_point):
    full_prompt = prompter.generate_quote_prompt(
        data_point["quote"], data_point["tags"]
    )
    tokenized_full_prompt = tokenize(full_prompt)
    # If you need the input-masking logic from the original script (train_on_inputs=False),
    # build the user prompt with prompter.generate_quote_prompt as well.
    return tokenized_full_prompt
```
4. Choose the `english_quotes` prompt template when launching training:
```bash
python ./quotes_qlora_finetuning_cpu.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "./english_quotes" \
--output_dir "./bigdl-qlora-alpaca" \
--prompt_template_name "english_quotes"
```
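Before launching the training job, you can quickly sanity-check the new template and prompter method from the example directory. This is a minimal sketch; the quote and tags are illustrative values in the style of the `english_quotes` dataset:
```python
# Run from the example directory so that templates/english_quotes.json is found.
from utils.prompter import Prompter

prompter = Prompter("english_quotes")              # loads templates/english_quotes.json from step 1
print(prompter.generate_quote_prompt(              # method added in step 2
    "Be yourself; everyone else is already taken.",    # illustrative quote
    ["be-yourself", "honesty", "inspirational"],        # illustrative tags
))
# Expected output:
# Be yourself; everyone else is already taken. ->: ['be-yourself', 'honesty', 'inspirational']
```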

View file

@ -0,0 +1,332 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/tloen/alpaca-lora/blob/main/finetune.py
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List
import fire
import torch
import transformers
from datasets import load_dataset
import accelerate
from transformers import LlamaTokenizer
from peft import (
    LoraConfig,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)
from utils.prompter import Prompter

from bigdl.llm.transformers import AutoModelForCausalLM
# import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training


def train(
    # model/data params
    base_model: str = "meta-llama/Llama-2-7b-hf",  # base model; defaults to "meta-llama/Llama-2-7b-hf"
    saved_low_bit_model: str = None,  # optional, the path to the saved model with bigdl-llm low-bit optimization
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "./bigdl-qlora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 2,  # defaults to 2, limited by CPU memory
    num_epochs: int = 3,
    max_steps: int = -1,  # if set to a positive number, it will override num_train_epochs
    learning_rate: float = 3e-5,  # defaults to 3e-5 to avoid divergence
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "up_proj",
        "down_proj",
        "gate_proj"
    ],  # according to the QLoRA paper (https://arxiv.org/pdf/2305.14314.pdf), it's suggested to fine-tune all linear layers
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca",  # the prompt template to use; defaults to alpaca
    gradient_checkpointing: bool = False,
    deepspeed: str = None,
):
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    prompter = Prompter(prompt_template_name)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    if saved_low_bit_model is not None:
        # Load the low-bit optimized model if the saved path is provided
        model = AutoModelForCausalLM.load_low_bit(
            saved_low_bit_model,
            optimize_model=False,
            torch_dtype=torch.bfloat16,
            modules_to_not_convert=["lm_head"],
        )
    else:
        # Load the base model from a directory or the HF Hub to 4-bit NormalFloat format
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_low_bit="sym_int4",  # "nf4" is not supported
            optimize_model=False,
            torch_dtype=torch.bfloat16,
            # device_map=device_map,
            modules_to_not_convert=["lm_head"],
        )
    print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
    model = model.to("cpu")
    print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")

    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")
    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    tokenizer.padding_side = "left"  # Allow batched inference

    print(model)

    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(data_point):
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = prompter.generate_prompt(
                data_point["instruction"], data_point["input"]
            )
            tokenized_user_prompt = tokenize(
                user_prompt, add_eos_token=add_eos_token
            )
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            if add_eos_token:
                user_prompt_len -= 1

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt

    # Prepare a BigDL-LLM compatible Peft model
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=gradient_checkpointing)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)

    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # warmup_ratio=0.03,
            # warmup_steps=100,
            max_steps=max_steps,
            max_grad_norm=0.3,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            lr_scheduler_type="cosine",
            bf16=True,  # bf16 makes training more stable
            logging_steps=1,
            optim="adamw_torch",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=100 if val_set_size > 0 else None,
            save_steps=100,
            output_dir=output_dir,
            save_total_limit=100,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
            gradient_checkpointing=gradient_checkpointing,
            ddp_backend="ccl",
            deepspeed=deepspeed,
        ),
        # Inputs are dynamically padded to the maximum length among all inputs
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


if __name__ == "__main__":
    fire.Fire(train)

View file

@ -0,0 +1,46 @@
# Prompt templates
This directory contains template styles for the prompts used to finetune LoRA models.
## Format
A template is described via a JSON file with the following keys:
- `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders.
- `prompt_no_input`: The template to use when input is None. Uses the `{instruction}` placeholder.
- `description`: A short description of the template, with possible use cases.
- `response_split`: The text to use as a separator when cutting the real response out of the model output.
No `{response}` placeholder is used, since the response is always the last element of the template and is simply concatenated to the rest.
## Example template
The default template, used unless otherwise specified, is `alpaca.json`
```json
{
"description": "Template used by Alpaca-LoRA.",
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
"response_split": "### Response:"
}
```
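For illustration, the snippet below sketches how these keys are consumed, mirroring `utils/prompter.py` in this example (the instruction and model output are made-up values):
```python
import json

# Load the default template and build a prompt from it.
with open("templates/alpaca.json") as fp:
    template = json.load(fp)

instruction = "Name three primary colors."   # illustrative instruction
context = None                               # no extra input in this example

if context:
    prompt = template["prompt_input"].format(instruction=instruction, input=context)
else:
    prompt = template["prompt_no_input"].format(instruction=instruction)

# After generation, the real response is whatever follows `response_split`.
model_output = prompt + "Red, yellow and blue."   # illustrative model output
response = model_output.split(template["response_split"])[1].strip()
print(response)  # Red, yellow and blue.
```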
## Current templates
### alpaca
Default template used for generic LoRA fine tunes so far.
### alpaca_legacy
Legacy template used by the original alpaca repo, with no `\n` after the response field. Kept for reference and experiments.
### alpaca_short
A trimmed-down alpaca template which seems to perform just as well and save some tokens. Models created with the default template seem to be queryable by the short template as well. More experiments are welcome.
### vigogne
The default alpaca template, translated to French. This template was used to train the "Vigogne" LoRA and is to be used to query it, or for extra fine-tuning.

View file

@ -0,0 +1,7 @@
{
"//": "This file is copied from https://github.com/tloen/alpaca-lora/blob/main/templates/alpaca.json",
"description": "Template used by Alpaca-LoRA.",
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
"response_split": "### Response:"
}

View file

@ -0,0 +1,7 @@
{
"//": "This file is copied from https://github.com/tloen/alpaca-lora/blob/main/templates/alpaca_legacy.json",
"description": "Legacy template, used by Original Alpaca repository.",
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:",
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
"response_split": "### Response:"
}

View file

@ -0,0 +1,7 @@
{
"//": "This file is copied from https://github.com/tloen/alpaca-lora/blob/main/templates/alpaca_short.json",
"description": "A shorter template to experiment with.",
"prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
"prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
"response_split": "### Response:"
}

View file

@ -0,0 +1,7 @@
{
"//": "This file is copied from https://github.com/tloen/alpaca-lora/blob/main/templates/vigogne.json",
"description": "French template, used by Vigogne for finetuning.",
"prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n",
"prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n",
"response_split": "### Réponse:"
}

View file

@ -0,0 +1,81 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/tloen/alpaca-lora/blob/main/utils/prompter.py
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os.path as osp
from typing import Union
from bigdl.llm.utils.common import invalidInputError
class Prompter(object):
    __slots__ = ("template", "_verbose")

    def __init__(self, template_name: str = "", verbose: bool = False):
        self._verbose = verbose
        if not template_name:
            # Enforce the default here, so the constructor can be called with '' and will not break.
            template_name = "alpaca"
        file_name = osp.join("templates", f"{template_name}.json")
        if not osp.exists(file_name):
            invalidInputError(False, f"Can't read {file_name}")
        with open(file_name) as fp:
            self.template = json.load(fp)
        if self._verbose:
            print(
                f"Using prompt template {template_name}: {self.template['description']}"
            )

    def generate_prompt(
        self,
        instruction: str,
        input: Union[None, str]=None,
        label: Union[None, str]=None,
    ) -> str:
        # returns the full prompt from instruction and optional input
        # if a label (=response, =output) is provided, it's also appended.
        if input:
            res = self.template["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            res = self.template["prompt_no_input"].format(
                instruction=instruction
            )
        if label:
            res = f"{res}{label}"
        if self._verbose:
            print(res)
        return res

    def get_response(self, output: str) -> str:
        return output.split(self.template["response_split"])[1].strip()

View file

@ -44,7 +44,8 @@ if __name__ == "__main__":
        row['prediction'] = row['quote'] + ' ->: ' + str(row['tags'])
        return row
    data['train'] = data['train'].map(merge)
    data = data.map(lambda samples: tokenizer(samples["prediction"]), batched=True)
    # use max_length to reduce memory usage; it should be adjusted for different datasets
    data = data.map(lambda samples: tokenizer(samples["prediction"], max_length=256), batched=True)
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_low_bit="sym_int4",
                                                 optimize_model=False,
@ -85,6 +86,7 @@ if __name__ == "__main__":
optim="adamw_hf", # paged_adamw_8bit is not supported yet
# gradient_checkpointing=True, # can further reduce memory but slower
),
# Inputs are dynamically padded to the maximum length of a batch
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False # silence the warnings. Please re-enable for inference!