LLM: reorganize GPU finetuning examples (#9952)

This commit is contained in:
binbin Deng 2024-01-25 19:02:38 +08:00 committed by GitHub
parent 175027c90f
commit 171fb2d185
60 changed files with 1895 additions and 378 deletions

View file

@@ -13,13 +13,13 @@
### Latest update 🔥
- [2024/01] 🔔🔔🔔 ***Starting from 2024/01/08, the default `bigdl-llm` GPU Linux installation switched from PyTorch 2.0 to PyTorch 2.1, which requires new oneAPI and GPU driver versions. (See the [GPU installation guide](https://bigdl.readthedocs.io/en/latest/doc/LLM/Overview/install_gpu.html) for more details.)***
- [2023/12] `bigdl-llm` now supports [ReLoRA](python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora#relora) (see *["ReLoRA: High-Rank Training Through Low-Rank Updates"](https://arxiv.org/abs/2307.05695)*)
- [2023/12] `bigdl-llm` now supports [ReLoRA](python/llm/example/GPU/LLM-Finetuning/ReLora) (see *["ReLoRA: High-Rank Training Through Low-Rank Updates"](https://arxiv.org/abs/2307.05695)*)
- [2023/12] `bigdl-llm` now supports [Mixtral-8x7B](python/llm/example/GPU/HF-Transformers-AutoModels/Model/mixtral) on both Intel [GPU](python/llm/example/GPU/HF-Transformers-AutoModels/Model/mixtral) and [CPU](python/llm/example/CPU/HF-Transformers-AutoModels/Model/mixtral).
- [2023/12] `bigdl-llm` now supports [QA-LoRA](python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora#qa-lora) (see *["QA-LoRA: Quantization-Aware Low-Rank Adaptation of Large Language Models"](https://arxiv.org/abs/2309.14717)*)
- [2023/12] `bigdl-llm` now supports [QA-LoRA](python/llm/example/GPU/LLM-Finetuning/QA-LoRA) (see *["QA-LoRA: Quantization-Aware Low-Rank Adaptation of Large Language Models"](https://arxiv.org/abs/2309.14717)*)
- [2023/12] `bigdl-llm` now supports [FP8 and FP4 inference](python/llm/example/GPU/HF-Transformers-AutoModels/More-Data-Types) on Intel ***GPU***.
- [2023/11] Initial support for directly loading [GGUF](python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF), [AWQ](python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ) and [GPTQ](python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ) models into `bigdl-llm` is available.
- [2023/11] `bigdl-llm` now supports [vLLM continuous batching](python/llm/example/GPU/vLLM-Serving) on both Intel [GPU](python/llm/example/GPU/vLLM-Serving) and [CPU](python/llm/example/CPU/vLLM-Serving).
- [2023/10] `bigdl-llm` now supports [QLoRA finetuning](python/llm/example/GPU/QLoRA-FineTuning) on both Intel [GPU](python/llm/example/GPU/QLoRA-FineTuning) and [CPU](python/llm/example/CPU/QLoRA-FineTuning).
- [2023/10] `bigdl-llm` now supports [QLoRA finetuning](python/llm/example/GPU/LLM-Finetuning/QLoRA) on both Intel [GPU](python/llm/example/GPU/LLM-Finetuning/QLoRA) and [CPU](python/llm/example/CPU/QLoRA-FineTuning).
- [2023/10] `bigdl-llm` now supports [FastChat serving](python/llm/src/bigdl/llm/serving) on both Intel CPU and GPU.
- [2023/09] `bigdl-llm` now supports [Intel GPU](python/llm/example/GPU) (including Arc, Flex and MAX)
- [2023/09] `bigdl-llm` [tutorial](https://github.com/intel-analytics/bigdl-llm-tutorial) is released.

View file

@@ -109,7 +109,7 @@ TrainOutput(global_step=200, training_loss=1.5072882556915284, metrics={'train_r
### 4. Merge the adapter into the original model
Use [export_merged_model.py](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py) to merge the adapter.
Use [export_merged_model.py](../../../../../../python/llm/example/GPU/LLM-Finetuning/QLoRA/export_merged_model.py) to merge the adapter.
```
python ./export_merged_model.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --adapter_path ./outputs/checkpoint-200 --output_path ./outputs/checkpoint-200-merged

View file

@@ -33,6 +33,6 @@ RUN curl -fsSL https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-P
# install huggingface dependencies
pip install git+https://github.com/huggingface/transformers.git@${TRANSFORMERS_COMMIT_ID} && \
pip install peft==0.5.0 datasets && \
wget https://raw.githubusercontent.com/intel-analytics/BigDL/main/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
wget https://raw.githubusercontent.com/intel-analytics/BigDL/main/python/llm/example/GPU/LLM-Finetuning/QLoRA/simple-example/qlora_finetuning.py
COPY ./start-qlora-finetuning-on-xpu.sh /start-qlora-finetuning-on-xpu.sh

View file

@@ -25,13 +25,13 @@ BigDL-LLM: low-Bit LLM library
Latest update 🔥
============================================
- [2024/01] 🔔🔔🔔 **Starting from 2024/01/08, the default** ``bigdl-llm`` **GPU Linux installation switched from PyTorch 2.0 to PyTorch 2.1, which requires new oneAPI and GPU driver versions. (See the** `GPU installation guide <https://bigdl.readthedocs.io/en/latest/doc/LLM/Overview/install_gpu.html>`_ **for more details.)**
- [2023/12] ``bigdl-llm`` now supports `ReLoRA <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora#relora>`_ (see `"ReLoRA: High-Rank Training Through Low-Rank Updates" <https://arxiv.org/abs/2307.05695>`_)
- [2023/12] ``bigdl-llm`` now supports `ReLoRA <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/LLM-Finetuning/ReLora>`_ (see `"ReLoRA: High-Rank Training Through Low-Rank Updates" <https://arxiv.org/abs/2307.05695>`_)
- [2023/12] ``bigdl-llm`` now supports `Mixtral-8x7B <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/HF-Transformers-AutoModels/Model/mixtral>`_ on both Intel `GPU <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/HF-Transformers-AutoModels/Model/mixtral>`_ and `CPU <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/HF-Transformers-AutoModels/Model/mixtral>`_.
- [2023/12] ``bigdl-llm`` now supports `QA-LoRA <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora#qa-lora>`_ (see `"QA-LoRA: Quantization-Aware Low-Rank Adaptation of Large Language Models" <https://arxiv.org/abs/2309.14717>`_).
- [2023/12] ``bigdl-llm`` now supports `QA-LoRA <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/LLM-Finetuning/QA-LoRA>`_ (see `"QA-LoRA: Quantization-Aware Low-Rank Adaptation of Large Language Models" <https://arxiv.org/abs/2309.14717>`_).
- [2023/12] ``bigdl-llm`` now supports `FP8 and FP4 inference <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/HF-Transformers-AutoModels/More-Data-Types>`_ on Intel **GPU**.
- [2023/11] Initial support for directly loading `GGUF <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF>`_, `AWQ <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ>`_ and `GPTQ <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ>`_ models into ``bigdl-llm`` is available.
- [2023/11] ``bigdl-llm`` now supports `vLLM continuous batching <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/vLLM-Serving>`_ on both Intel `GPU <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/vLLM-Serving>`_ and `CPU <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/vLLM-Serving>`_.
- [2023/10] ``bigdl-llm`` now supports `QLoRA finetuning <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/QLoRA-FineTuning>`_ on both Intel `GPU <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/QLoRA-FineTuning>`_ and `CPU <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/QLoRA-FineTuning>`_.
- [2023/10] ``bigdl-llm`` now supports `QLoRA finetuning <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/LLM-Finetuning/QLoRA>`_ on both Intel `GPU <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU/LLM-Finetuning/QLoRA>`_ and `CPU <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/QLoRA-FineTuning>`_.
- [2023/10] ``bigdl-llm`` now supports `FastChat serving <https://github.com/intel-analytics/BigDL/tree/main/python/llm/src/bigdl/llm/serving>`_ on both Intel CPU and GPU.
- [2023/09] ``bigdl-llm`` now supports `Intel GPU <https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU>`_ (including Arc, Flex and MAX)
- [2023/09] ``bigdl-llm`` `tutorial <https://github.com/intel-analytics/bigdl-llm-tutorial>`_ is released.

View file

@@ -54,7 +54,7 @@ TrainOutput(global_step=200, training_loss=1.3923714351654053, metrics={'train_r
```
### 3. Merge the adapter into the original model
Use [export_merged_model.py](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py) to merge the adapter.
Use [export_merged_model.py](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/GPU/LLM-Finetuning/QLoRA/export_merged_model.py) to merge the adapter.
```
python ./export_merged_model.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --adapter_path ./outputs/checkpoint-200 --output_path ./outputs/checkpoint-200-merged
```

View file

@@ -143,7 +143,7 @@ lora_target_modules: List[str] = ["W_pack"]
5. (Only for baichuan) According to this [issue](https://github.com/baichuan-inc/Baichuan2/issues/204#issuecomment-1774372008),
you need to modify [tokenization_baichuan.py](https://huggingface.co/baichuan-inc/Baichuan-7B/blob/main/tokenization_baichuan.py#L74) to fix the issue.
6. Finetune as normal.
7. Use [export_merged_model.py](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py) to merge. You also need to update the tokenizer and model to ensure the weights merge successfully.
7. Use [export_merged_model.py](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/GPU/LLM-Finetuning/QLoRA/export_merged_model.py) to merge. You also need to update the tokenizer and model to ensure the weights merge successfully.
```bash
from transformers import AutoTokenizer # noqa: F402

View file

@@ -0,0 +1,90 @@
# LoRA Finetuning with BigDL-LLM
This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM (using the [LoRA](https://arxiv.org/abs/2106.09685) algorithm) on [Intel GPU](../../README.md).
### 0. Requirements
To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../README.md#requirements) for more information.
### 1. Install
```bash
conda create -n llm python=3.9
conda activate llm
# the command below will install intel_extension_for_pytorch==2.1.10+xpu by default
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
pip install transformers==4.34.0 datasets
pip install fire peft==0.5.0
pip install oneccl_bind_pt==2.1.100 -f https://developer.intel.com/ipex-whl-stable-xpu # necessary to run distributed finetuning
pip install accelerate==0.23.0
pip install bitsandbytes scipy
```
### 2. Configure OneAPI environment variables
```bash
source /opt/intel/oneapi/setvars.sh
```
### 3. LoRA Finetune
Here, we provide example usages on different hardware. Please refer to the appropriate script based on your device (a sketch of a typical launch command follows this list):
##### Finetuning LLaMA2-7B on single Arc A770
```bash
bash lora_finetune_llama2_7b_arc_1_card.sh
```
##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1100
```bash
bash lora_finetune_llama2_7b_pvc_1100_1_card.sh
```
##### Finetuning LLaMA2-7B on single tile of Intel Data Center GPU Max 1550
```bash
bash lora_finetune_llama2_7b_pvc_1550_1_tile.sh
```
##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1550
```bash
bash lora_finetune_llama2_7b_pvc_1550_4_card.sh
```
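For reference, the launch scripts above are wrappers around `alpaca_lora_finetuning.py`. A single-card invocation looks roughly like the following sketch (the arguments shown are illustrative; the scripts in this folder remain the source of truth):
```bash
# Illustrative single-card LoRA launch; see the provided *_1_card.sh scripts for the exact arguments.
python ./alpaca_lora_finetuning.py \
    --micro_batch_size 8 \
    --batch_size 128 \
    --base_model "meta-llama/Llama-2-7b-hf" \
    --data_path "yahma/alpaca-cleaned" \
    --output_dir "./bigdl-lora-alpaca" \
    --gradient_checkpointing True \
    --lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj']"
```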
### 4. (Optional) Resume Training
**If you fail to complete the whole finetuning process, it is suggested to resume training from a previously saved checkpoint by setting `resume_from_checkpoint` to the local checkpoint folder, as follows:**
```bash
python ./alpaca_lora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--resume_from_checkpoint "./bigdl-qlora-alpaca/checkpoint-1100"
```
### 5. Sample Output
```log
{'loss': 1.9231, 'learning_rate': 2.9999945367033285e-05, 'epoch': 0.0}
{'loss': 1.8622, 'learning_rate': 2.9999781468531096e-05, 'epoch': 0.01}
{'loss': 1.9043, 'learning_rate': 2.9999508305687345e-05, 'epoch': 0.01}
{'loss': 1.8967, 'learning_rate': 2.999912588049185e-05, 'epoch': 0.01}
{'loss': 1.9658, 'learning_rate': 2.9998634195730358e-05, 'epoch': 0.01}
{'loss': 1.8386, 'learning_rate': 2.9998033254984483e-05, 'epoch': 0.02}
{'loss': 1.809, 'learning_rate': 2.999732306263172e-05, 'epoch': 0.02}
{'loss': 1.8552, 'learning_rate': 2.9996503623845395e-05, 'epoch': 0.02}
1%|█ | 8/1164 [xx:xx<xx:xx:xx, xx s/it]
```
### 6. Merge the adapter into the original model
```
python ./export_merged_model.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --adapter_path ./outputs/checkpoint-200 --output_path ./outputs/checkpoint-200-merged
```
Then you can use `./outputs/checkpoint-200-merged` as a normal Hugging Face Transformers model for inference.
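Below is a minimal, illustrative inference sketch with the merged checkpoint, assuming the standard Hugging Face `transformers` generation API (the path and prompt are placeholders):
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the merged checkpoint like any other Hugging Face causal LM
model_path = "./outputs/checkpoint-200-merged"  # placeholder path from the merge step above
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)

prompt = "### Instruction:\nTell me about Intel GPUs.\n\n### Response:\n"  # placeholder Alpaca-style prompt
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```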
### 7. Troubleshooting
- If you fail to finetune on multiple cards because of the following error message:
```bash
RuntimeError: oneCCL: comm_selector.cpp:57 create_comm_impl: EXCEPTION: ze_data was not initialized
```
Please try `sudo apt install level-zero-dev` to fix it.

View file

@@ -0,0 +1,267 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/tloen/alpaca-lora/blob/main/finetune.py
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List
import fire
import torch
import transformers
from datasets import load_dataset
import accelerate
from transformers import LlamaTokenizer
from peft import (
get_peft_model_state_dict,
set_peft_model_state_dict,
)
current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..')
import sys
sys.path.append(common_util_path)
from common.utils import Prompter, get_int_from_env, wandb_check, get_train_val_data
from transformers import BitsAndBytesConfig
from bigdl.llm.transformers import AutoModelForCausalLM
# import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
LoraConfig
from bigdl.llm.utils.common import invalidInputError
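# Map the rank/size variables set by the MPI/oneCCL launcher (e.g. MPI_LOCALRANKID, PMI_SIZE)
# to the LOCAL_RANK/WORLD_SIZE/RANK/MASTER_PORT variables expected by torch.distributed.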
local_rank = get_int_from_env(["LOCAL_RANK","MPI_LOCALRANKID"], "0")
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
port = get_int_from_env(["MASTER_PORT"], 29500)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(local_rank)
os.environ["MASTER_PORT"] = str(port)
def train(
# model/data params
base_model: str = "meta-llama/Llama-2-7b-hf", # the only required argument, default to be "meta-llama/Llama-2-7b-hf"
saved_low_bit_model: str = None, # optional, the path to the saved model with bigdl-llm low-bit optimization
data_path: str = "yahma/alpaca-cleaned",
output_dir: str = "./bigdl-qlora-alpaca",
# training hyperparams
bf16: bool = True, # default to bf16
batch_size: int = 128,
micro_batch_size: int = 2, # default to be 2, limited by GPU memory
num_epochs: int = 3,
learning_rate: float = 3e-5, # default to be 3e-5 to avoid divergence
cutoff_len: int = 256,
val_set_size: int = 2000,
# lora hyperparams
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules: List[str] = [
"q_proj",
"v_proj",
"k_proj",
"o_proj",
"up_proj",
"down_proj",
"gate_proj"
],
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
add_eos_token: bool = False,
group_by_length: bool = False, # faster, but produces an odd training loss curve
# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
gradient_checkpointing: bool = False,
deepspeed: str = None,
training_mode: str = "lora",
):
invalidInputError(training_mode == "lora",
f"This example is for lora training mode, but got training_mode={training_mode}.")
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
print(
f"Training Alpaca-LoRA model with params:\n"
f"base_model: {base_model}\n"
f"data_path: {data_path}\n"
f"output_dir: {output_dir}\n"
f"batch_size: {batch_size}\n"
f"micro_batch_size: {micro_batch_size}\n"
f"num_epochs: {num_epochs}\n"
f"learning_rate: {learning_rate}\n"
f"cutoff_len: {cutoff_len}\n"
f"val_set_size: {val_set_size}\n"
f"lora_r: {lora_r}\n"
f"lora_alpha: {lora_alpha}\n"
f"lora_dropout: {lora_dropout}\n"
f"lora_target_modules: {lora_target_modules}\n"
f"train_on_inputs: {train_on_inputs}\n"
f"add_eos_token: {add_eos_token}\n"
f"group_by_length: {group_by_length}\n"
f"wandb_project: {wandb_project}\n"
f"wandb_run_name: {wandb_run_name}\n"
f"wandb_watch: {wandb_watch}\n"
f"wandb_log_model: {wandb_log_model}\n"
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
f"prompt template: {prompt_template_name}\n"
f"training_mode: {training_mode}\n"
)
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
gradient_accumulation_steps = batch_size // micro_batch_size
prompter = Prompter(prompt_template_name)
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
gradient_accumulation_steps = gradient_accumulation_steps // world_size
# Check if parameter passed or if set within environ
use_wandb = wandb_check(wandb_project, wandb_watch, wandb_log_model)
if saved_low_bit_model is not None:
# Load the low-bit optimized model if a saved path is provided
model = AutoModelForCausalLM.load_low_bit(
saved_low_bit_model,
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
)
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_low_bit="bf16",
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
)
print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
tokenizer = LlamaTokenizer.from_pretrained(base_model)
print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")
tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference
print(model)
# Prepare a BigDL-LLM compatible Peft model
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=gradient_checkpointing)
config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
training_mode=training_mode,
)
print(f"Lora Config: {config}")
model = get_peft_model(model, config)
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
data = load_dataset("json", data_files=data_path)
else:
data = load_dataset(data_path)
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
train_data, val_data = get_train_val_data(data, tokenizer, prompter, train_on_inputs,
add_eos_token, cutoff_len, val_set_size, seed=42)
# Unused
# if not ddp and torch.cuda.device_count() > 1:
# # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
# model.is_parallelizable = True
# model.model_parallel = True
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
# warmup_ratio=0.03,
# warmup_steps=100,
max_grad_norm=0.3,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
lr_scheduler_type="cosine",
bf16=True, # ensure more stable training
logging_steps=1,
optim="adamw_torch",
evaluation_strategy="steps" if val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=100 if val_set_size > 0 else None,
save_steps=100,
output_dir=output_dir,
save_total_limit=100,
load_best_model_at_end=True if val_set_size > 0 else False,
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,
report_to="wandb" if use_wandb else None,
run_name=wandb_run_name if use_wandb else None,
gradient_checkpointing=gradient_checkpointing,
ddp_backend="ccl",
deepspeed=deepspeed,
save_safetensors=False,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
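# The KV cache is only useful for generation; disable it during training (it also conflicts with gradient checkpointing).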
model.config.use_cache = False
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
model.save_pretrained(output_dir)
print(
"\n If there's a warning about missing keys above, please disregard :)"
)
if __name__ == "__main__":
fire.Fire(train)

View file

@@ -0,0 +1,44 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import torch
from transformers import LlamaTokenizer # noqa: F402
import argparse
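# Example usage (paths are illustrative; see the README for the exact command):
#   python ./export_merged_model.py --repo-id-or-model-path meta-llama/Llama-2-7b-hf \
#       --adapter_path ./outputs/checkpoint-200 --output_path ./outputs/checkpoint-200-merged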
current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..')
import sys
sys.path.append(common_util_path)
from common.utils import merge_adapter
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Merge the adapter into the original model for Llama2 model')
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-hf",
help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--adapter_path', type=str,)
parser.add_argument('--output_path', type=str,)
args = parser.parse_args()
base_model = model_path = args.repo_id_or_model_path
adapter_path = args.adapter_path
output_path = args.output_path
tokenizer = LlamaTokenizer.from_pretrained(base_model)
merge_adapter(base_model, tokenizer, adapter_path, output_path)
print(f'Finished merging the adapter into the original model. You can find the merged model in {output_path}.')

View file

@@ -15,12 +15,11 @@
#
# You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file
python ./alpaca_qlora_finetuning.py \
python ./alpaca_lora_finetuning.py \
--micro_batch_size 8 \
--batch_size 128 \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-lora-alpaca" \
--gradient_checkpointing True \
--lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj']" \
--training_mode "lora"
--lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj']"

View file

@@ -20,12 +20,11 @@ export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
mpirun -n 4 \
python -u ./alpaca_qlora_finetuning.py \
python -u ./alpaca_lora_finetuning.py \
--micro_batch_size 8 \
--batch_size 128 \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-lora-alpaca" \
--gradient_checkpointing True \
--lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj', 'up_proj', 'down_proj', 'gate_proj']" \
--training_mode "lora"
--lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj', 'up_proj', 'down_proj', 'gate_proj']"

View file

@@ -15,12 +15,11 @@
#
# You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file
python ./alpaca_qlora_finetuning.py \
python ./alpaca_lora_finetuning.py \
--micro_batch_size 8 \
--batch_size 128 \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-lora-alpaca" \
--gradient_checkpointing True \
--lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj', 'up_proj', 'down_proj', 'gate_proj']" \
--training_mode "lora"
--lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj', 'up_proj', 'down_proj', 'gate_proj']"

View file

@@ -15,17 +15,16 @@
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=7
export OMP_NUM_THREADS=56
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
mpirun -n 8 \
python -u ./alpaca_qlora_finetuning.py \
python -u ./alpaca_lora_finetuning.py \
--micro_batch_size 8 \
--batch_size 128 \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-lora-alpaca" \
--gradient_checkpointing False \
--lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj', 'up_proj', 'down_proj', 'gate_proj']" \
--training_mode "lora"
--lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj', 'up_proj', 'down_proj', 'gate_proj']"

View file

@@ -0,0 +1,84 @@
# QA-LoRA Finetuning with BigDL-LLM
This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM (using the [QA-LoRA](https://arxiv.org/abs/2309.14717) algorithm) on [Intel GPU](../../README.md).
### 0. Requirements
To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../README.md#requirements) for more information.
### 1. Install
```bash
conda create -n llm python=3.9
conda activate llm
# the command below will install intel_extension_for_pytorch==2.1.10+xpu by default
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
pip install transformers==4.34.0 datasets
pip install fire peft==0.5.0
pip install oneccl_bind_pt==2.1.100 -f https://developer.intel.com/ipex-whl-stable-xpu # necessary to run distributed finetuning
pip install accelerate==0.23.0
pip install bitsandbytes scipy
```
### 2. Configure OneAPI environment variables
```bash
source /opt/intel/oneapi/setvars.sh
```
### 3. QA-LoRA Finetune
Here, we provide example usages on different hardware. Please refer to the appropriate script based on your device:
##### Finetuning LLaMA2-7B on single Arc A770
```bash
bash qalora_finetune_llama2_7b_arc_1_card.sh
```
##### Finetuning LLaMA2-7B on two Arc A770
```bash
bash qalora_finetune_llama2_7b_arc_2_card.sh
```
##### Finetuning LLaMA2-7B on single tile of Intel Data Center GPU Max 1550
```bash
bash qalora_finetune_llama2_7b_pvc_1550_1_tile.sh
```
### 4. (Optional) Resume Training
**If you fail to complete the whole finetuning process, it is suggested to resume training from a previously saved checkpoint by setting `resume_from_checkpoint` to the local checkpoint folder, as follows:**
```bash
python ./alpaca_qalora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--resume_from_checkpoint "./bigdl-qlora-alpaca/checkpoint-1100"
```
### 5. Sample Output
```log
{'loss': 1.9231, 'learning_rate': 2.9999945367033285e-05, 'epoch': 0.0}
{'loss': 1.8622, 'learning_rate': 2.9999781468531096e-05, 'epoch': 0.01}
{'loss': 1.9043, 'learning_rate': 2.9999508305687345e-05, 'epoch': 0.01}
{'loss': 1.8967, 'learning_rate': 2.999912588049185e-05, 'epoch': 0.01}
{'loss': 1.9658, 'learning_rate': 2.9998634195730358e-05, 'epoch': 0.01}
{'loss': 1.8386, 'learning_rate': 2.9998033254984483e-05, 'epoch': 0.02}
{'loss': 1.809, 'learning_rate': 2.999732306263172e-05, 'epoch': 0.02}
{'loss': 1.8552, 'learning_rate': 2.9996503623845395e-05, 'epoch': 0.02}
1%|█ | 8/1164 [xx:xx<xx:xx:xx, xx s/it]
```
### 6. Merge the adapter into the original model
```
python ./export_merged_model.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --adapter_path ./outputs/checkpoint-200 --output_path ./outputs/checkpoint-200-merged
```
Then you can use `./outputs/checkpoint-200-merged` as a normal Hugging Face Transformers model for inference.
### 7. Troubleshooting
- If you fail to finetune on multiple cards because of the following error message:
```bash
RuntimeError: oneCCL: comm_selector.cpp:57 create_comm_impl: EXCEPTION: ze_data was not initialized
```
Please try `sudo apt install level-zero-dev` to fix it.

View file

@@ -0,0 +1,279 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/tloen/alpaca-lora/blob/main/finetune.py
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List
import fire
import torch
import transformers
from datasets import load_dataset
import accelerate
from transformers import LlamaTokenizer
from peft import (
get_peft_model_state_dict,
set_peft_model_state_dict,
)
current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..')
import sys
sys.path.append(common_util_path)
from common.utils import Prompter, get_int_from_env, wandb_check, get_train_val_data
from transformers import BitsAndBytesConfig
from bigdl.llm.transformers import AutoModelForCausalLM
# import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
LoraConfig
from bigdl.llm.utils.common import invalidInputError
local_rank = get_int_from_env(["LOCAL_RANK","MPI_LOCALRANKID"], "0")
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
port = get_int_from_env(["MASTER_PORT"], 29500)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(local_rank)
os.environ["MASTER_PORT"] = str(port)
def train(
# model/data params
base_model: str = "meta-llama/Llama-2-7b-hf", # the only required argument, default to be "meta-llama/Llama-2-7b-hf"
saved_low_bit_model: str = None, # optional, the path to the saved model with bigdl-llm low-bit optimization
data_path: str = "yahma/alpaca-cleaned",
output_dir: str = "./bigdl-qlora-alpaca",
# training hyperparams
bf16: bool = True, # default to bf16
batch_size: int = 128,
micro_batch_size: int = 2, # default to be 2, limited by GPU memory
num_epochs: int = 3,
learning_rate: float = 3e-5, # default to be 3e-5 to avoid divergence
cutoff_len: int = 256,
val_set_size: int = 2000,
# lora hyperparams
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules: List[str] = [
"q_proj",
"v_proj",
"k_proj",
"o_proj",
"up_proj",
"down_proj",
"gate_proj"
],
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
add_eos_token: bool = False,
group_by_length: bool = False, # faster, but produces an odd training loss curve
# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
gradient_checkpointing: bool = False,
deepspeed: str = None,
training_mode: str = "qalora",
):
invalidInputError(training_mode == "qalora",
f"This example is for qalora training mode, but got training_mode={training_mode}.")
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
print(
f"Training Alpaca-LoRA model with params:\n"
f"base_model: {base_model}\n"
f"data_path: {data_path}\n"
f"output_dir: {output_dir}\n"
f"batch_size: {batch_size}\n"
f"micro_batch_size: {micro_batch_size}\n"
f"num_epochs: {num_epochs}\n"
f"learning_rate: {learning_rate}\n"
f"cutoff_len: {cutoff_len}\n"
f"val_set_size: {val_set_size}\n"
f"lora_r: {lora_r}\n"
f"lora_alpha: {lora_alpha}\n"
f"lora_dropout: {lora_dropout}\n"
f"lora_target_modules: {lora_target_modules}\n"
f"train_on_inputs: {train_on_inputs}\n"
f"add_eos_token: {add_eos_token}\n"
f"group_by_length: {group_by_length}\n"
f"wandb_project: {wandb_project}\n"
f"wandb_run_name: {wandb_run_name}\n"
f"wandb_watch: {wandb_watch}\n"
f"wandb_log_model: {wandb_log_model}\n"
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
f"prompt template: {prompt_template_name}\n"
f"training_mode: {training_mode}\n"
)
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
gradient_accumulation_steps = batch_size // micro_batch_size
prompter = Prompter(prompt_template_name)
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
gradient_accumulation_steps = gradient_accumulation_steps // world_size
# Check if parameter passed or if set within environ
use_wandb = wandb_check(wandb_project, wandb_watch, wandb_log_model)
if saved_low_bit_model is not None:
# Load the low-bit optimized model if a saved path is provided
model = AutoModelForCausalLM.load_low_bit(
saved_low_bit_model,
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
)
else:
# Default 4-bit format for qa-lora is sym_int4
# use bnb_config for qalora, which uses 4-bit for the base model
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="int4",
bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(base_model,
quantization_config=bnb_config, )
# below is also supported
# Load the base model from a directory or the HF Hub to 4-bit format
# model = AutoModelForCausalLM.from_pretrained(
# base_model,
# load_in_low_bit="sym_int4",
# optimize_model=False,
# torch_dtype=torch.bfloat16,
# # device_map=device_map,
# modules_to_not_convert=["lm_head"],
# )
print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
tokenizer = LlamaTokenizer.from_pretrained(base_model)
print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")
tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference
print(model)
# Prepare a BigDL-LLM compatible Peft model
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=gradient_checkpointing)
config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
training_mode=training_mode,
)
print(f"Lora Config: {config}")
model = get_peft_model(model, config)
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
data = load_dataset("json", data_files=data_path)
else:
data = load_dataset(data_path)
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
train_data, val_data = get_train_val_data(data, tokenizer, prompter, train_on_inputs,
add_eos_token, cutoff_len, val_set_size, seed=42)
# Unused
# if not ddp and torch.cuda.device_count() > 1:
# # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
# model.is_parallelizable = True
# model.model_parallel = True
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
# warmup_ratio=0.03,
# warmup_steps=100,
max_grad_norm=0.3,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
lr_scheduler_type="constant",
bf16=True, # ensure more stable training
logging_steps=1,
optim="adamw_torch",
evaluation_strategy="steps" if val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=100 if val_set_size > 0 else None,
save_steps=100,
output_dir=output_dir,
save_total_limit=100,
load_best_model_at_end=True if val_set_size > 0 else False,
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,
report_to="wandb" if use_wandb else None,
run_name=wandb_run_name if use_wandb else None,
gradient_checkpointing=gradient_checkpointing,
ddp_backend="ccl",
deepspeed=deepspeed,
save_safetensors=False,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
model.config.use_cache = False
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
model.save_pretrained(output_dir)
print(
"\n If there's a warning about missing keys above, please disregard :)"
)
if __name__ == "__main__":
fire.Fire(train)

View file

@@ -0,0 +1,44 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import torch
from transformers import LlamaTokenizer # noqa: F402
import argparse
current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..')
import sys
sys.path.append(common_util_path)
from common.utils import merge_adapter
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Merge the adapter into the original model for Llama2 model')
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-hf",
help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--adapter_path', type=str,)
parser.add_argument('--output_path', type=str,)
args = parser.parse_args()
base_model = model_path = args.repo_id_or_model_path
adapter_path = args.adapter_path
output_path = args.output_path
tokenizer = LlamaTokenizer.from_pretrained(base_model)
merge_adapter(base_model, tokenizer, adapter_path, output_path)
print(f'Finished merging the adapter into the original model. You can find the merged model in {output_path}.')

View file

@@ -15,7 +15,7 @@
#
# You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file
python ./alpaca_qlora_finetuning.py \
python ./alpaca_qalora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
@@ -25,5 +25,4 @@ python ./alpaca_qlora_finetuning.py \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--val_set_size 2000 \
--training_mode "qalora"
--val_set_size 2000

View file

@@ -15,12 +15,12 @@
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=6 # adjust this to 1/4 of total physical cores
export OMP_NUM_THREADS=6
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
mpirun -n 2 \
python -u ./alpaca_qlora_finetuning.py \
python -u ./alpaca_qalora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
@@ -30,5 +30,4 @@ mpirun -n 2 \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--val_set_size 2000 \
--training_mode "qalora" > training.log
--val_set_size 2000 > training.log

View file

@@ -15,20 +15,19 @@
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=28 # adjust this to 1/4 of total physical cores
export OMP_NUM_THREADS=56
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
mpirun -n 2 \
python -u ./alpaca_qlora_finetuning.py \
python -u ./alpaca_qalora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--training_mode "qalora" \
--learning_rate 9e-5 \
--micro_batch_size 8 \
--batch_size 128 \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--val_set_size 2000 > training.log
--val_set_size 2000 > training.log

View file

@@ -16,7 +16,7 @@
# You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file
python ./alpaca_qlora_finetuning.py \
python ./alpaca_qalora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
@@ -27,5 +27,4 @@ python ./alpaca_qlora_finetuning.py \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--val_set_size 2000 \
--training_mode "qalora"
--val_set_size 2000

View file

@@ -0,0 +1,5 @@
# QLoRA Finetuning with BigDL-LLM
We provide an [Alpaca-QLoRA example](./alpaca-qlora/), which ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM (using the [QLoRA](https://arxiv.org/abs/2305.14314) algorithm) on [Intel GPU](../../README.md).
Meanwhile, we also provide a [simple example](./simple-example/) to help you get started with QLoRA Finetuning using BigDL-LLM.

View file

@@ -1,9 +1,11 @@
# Alpaca Finetuning with BigDL-LLM
# QLoRA Finetuning with BigDL-LLM
This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM (using either the [QLoRA](https://arxiv.org/abs/2305.14314) / [QA-LoRA](https://arxiv.org/abs/2309.14717) / [LoRA](https://arxiv.org/abs/2106.09685) or [ReLoRA](https://arxiv.org/abs/2307.05695) algorithm) on [Intel GPU](../../README.md).
This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM (using the [QLoRA](https://arxiv.org/abs/2305.14314) algorithm) on [Intel GPU](../../../README.md).
> Note: You can also refer to the [simple QLoRA example](../simple-example/) to try out related usage.
### 0. Requirements
To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../README.md#requirements) for more information.
To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../../README.md#requirements) for more information.
### 1. Install
@@ -17,6 +19,10 @@ pip install fire peft==0.5.0
pip install oneccl_bind_pt==2.1.100 -f https://developer.intel.com/ipex-whl-stable-xpu # necessary to run distributed finetuning
pip install accelerate==0.23.0
pip install bitsandbytes scipy
# configures OneAPI environment variables
source /opt/intel/oneapi/setvars.sh # necessary to run before installing deepspeed
pip install git+https://github.com/microsoft/DeepSpeed.git@78c518e
pip install git+https://github.com/intel/intel-extension-for-deepspeed.git@ec33277
```
### 2. Configure OneAPI environment variables
@@ -24,131 +30,104 @@ pip install bitsandbytes scipy
source /opt/intel/oneapi/setvars.sh
```
### 3. Finetune
### 3. QLoRA Finetune
Now we support four training modes ([QLoRA](https://arxiv.org/abs/2305.14314) / [QA-LoRA](https://arxiv.org/abs/2309.14717) / [LoRA](https://arxiv.org/abs/2106.09685) / [ReLoRA](https://arxiv.org/abs/2307.05695)); to run a different mode, just change `training_mode` to `qlora` / `qalora` / `lora` / `relora` in the script below.
Here, we provide example usages on different hardware. Please refer to the appropriate script based on your device and model:
Here, we provide example usages on different hardware. Please refer to the appropriate script based on your device:
#### QLoRA
<details>
<summary> Show LLaMA2-7B examples </summary>
##### Finetuning LLaMA2-7B on single Arc A770
```bash
bash finetune_llama2_7b_arc_1_card.sh
bash qlora_finetune_llama2_7b_arc_1_card.sh
```
##### Finetuning LLaMA2-7B on two Arc A770
```bash
bash finetune_llama2_7b_arc_2_card.sh
bash qlora_finetune_llama2_7b_arc_2_card.sh
```
##### Finetuning LLaMA2-7B on single Data Center GPU Flex 170
```bash
bash finetune_llama2_7b_flex_170_1_card.sh
bash qlora_finetune_llama2_7b_flex_170_1_card.sh
```
##### Finetuning LLaMA2-7B on three Data Center GPU Flex 170
```bash
bash finetune_llama2_7b_flex_170_3_card.sh
bash qlora_finetune_llama2_7b_flex_170_3_card.sh
```
##### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1100
```bash
bash finetune_llama2_7b_pvc_1100_1_card.sh
bash qlora_finetune_llama2_7b_pvc_1100_1_card.sh
```
##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1100
```bash
bash finetune_llama2_7b_pvc_1100_4_card.sh
bash qlora_finetune_llama2_7b_pvc_1100_4_card.sh
```
##### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1550
```bash
bash finetune_llama2_7b_pvc_1550_1_card.sh
bash qlora_finetune_llama2_7b_pvc_1550_1_card.sh
```
##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1550
```bash
bash finetune_llama2_7b_pvc_1550_4_card.sh
bash qlora_finetune_llama2_7b_pvc_1550_4_card.sh
```
#### QA-LoRA
##### Finetuning LLaMA2-7B on single Arc A770
</details>
<details>
<summary> Show LLaMA2-13B examples </summary>
##### Finetuning LLaMA2-13B on single tile of Intel Data Center GPU Max 1550
```bash
bash qalora_finetune_llama2_7b_arc_1_card.sh
bash qlora_finetune_llama2_13b_pvc_1550_1_tile.sh
```
##### Finetuning LLaMA2-7B on two Arc A770
##### Finetuning LLaMA2-13B on single Intel Data Center GPU Max 1550
```bash
bash qalora_finetune_llama2_7b_arc_2_card.sh
bash qlora_finetune_llama2_13b_pvc_1550_1_card.sh
```
##### Finetuning LLaMA2-7B on single Tile Intel Data Center GPU Max 1550
##### Finetuning LLaMA2-13B on four Intel Data Center GPU Max 1550
```bash
bash qalora_finetune_llama2_7b_pvc_1550_1_tile.sh
bash qlora_finetune_llama2_13b_pvc_1550_4_card.sh
```
#### LoRA
</details>
##### Finetuning LLaMA2-7B on single Arc A770
<details>
<summary> Show LLaMA2-70B examples </summary>
Different from `LLaMA2-7B` and `LLaMA2-13B`, it is recommended to first save the model with bigdl-llm low-bit optimization, to avoid a large amount of CPU memory usage. DeepSpeed ZeRO2 is used during finetuning.
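A minimal sketch of that first step is shown below, assuming bigdl-llm's `save_low_bit` API and an illustrative local path; the finetuning script can then pick the saved folder up via its `saved_low_bit_model` argument:
```python
# Sketch: save an nf4 low-bit copy of LLaMA2-70B once, so later finetuning runs can load it
# directly instead of re-quantizing the full-precision checkpoint (which needs a lot of CPU memory).
# Paths are illustrative; save_low_bit is assumed from bigdl-llm's transformers API.
import torch
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    load_in_low_bit="nf4",
    optimize_model=False,
    torch_dtype=torch.bfloat16,
    modules_to_not_convert=["lm_head"],
)
model.save_low_bit("./llama-2-70b-nf4")
```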
##### Finetuning LLaMA2-70B on one Intel Data Center GPU Max 1550
```bash
bash lora_finetune_llama2_7b_arc_1_card.sh
bash qlora_finetune_llama2_70b_pvc_1550_1_card.sh
```
##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1100
##### Finetuning LLaMA2-70B on four Intel Data Center GPU Max 1550
```bash
bash lora_finetune_llama2_7b_pvc_1100_1_card.sh
bash qlora_finetune_llama2_70b_pvc_1550_4_card.sh
```
##### Finetuning LLaMA2-7B on single Tile Intel Data Center GPU Max 1550
```bash
bash lora_finetune_llama2_7b_pvc_1550_1_tile.sh
```
##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1550
```bash
bash lora_finetune_llama2_7b_pvc_1550_4_card.sh
```
#### ReLoRA
##### Finetuning LLaMA2-7B on single Arc A770
```bash
bash relora_finetune_llama2_7b_arc_1_card.sh
```
##### Finetuning LLaMA2-7B on two Arc A770
```bash
bash relora_finetune_llama2_7b_arc_2_card.sh
```
##### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1550
```bash
bash relora_finetune_llama2_7b_pvc_1550_1_card.sh
```
##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1550
```bash
bash relora_finetune_llama2_7b_pvc_1550_4_card.sh
```
</details>
### 4. (Optional) Resume Training
**If you fail to complete the whole finetuning process, it is suggested to resume training from a previously saved checkpoint by setting `resume_from_checkpoint` to the local checkpoint folder, as follows:**
@@ -173,14 +152,14 @@ python ./alpaca_qlora_finetuning.py \
1%|█ | 8/1164 [xx:xx<xx:xx:xx, xx s/it]
```
### 4. Merge the adapter into the original model
### 6. Merge the adapter into the original model
```
python ./export_merged_model.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --adapter_path ./outputs/checkpoint-200 --output_path ./outputs/checkpoint-200-merged
```
Then you can use `./outputs/checkpoint-200-merged` as a normal Hugging Face Transformers model for inference.
### 5. Troubleshooting
### 7. Troubleshooting
- If you fail to finetune on multiple cards because of the following error message:
```bash
RuntimeError: oneCCL: comm_selector.cpp:57 create_comm_impl: EXCEPTION: ze_data was not initialized

View file

@@ -0,0 +1,279 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/tloen/alpaca-lora/blob/main/finetune.py
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List
import fire
import torch
import transformers
from datasets import load_dataset
import accelerate
from transformers import LlamaTokenizer
from peft import (
get_peft_model_state_dict,
set_peft_model_state_dict,
)
current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..', '..')
import sys
sys.path.append(common_util_path)
from common.utils import Prompter, get_int_from_env, wandb_check, get_train_val_data
from transformers import BitsAndBytesConfig
from bigdl.llm.transformers import AutoModelForCausalLM
# import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
LoraConfig
from bigdl.llm.utils.common import invalidInputError
local_rank = get_int_from_env(["LOCAL_RANK","MPI_LOCALRANKID"], "0")
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
port = get_int_from_env(["MASTER_PORT"], 29500)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(local_rank)
os.environ["MASTER_PORT"] = str(port)
def train(
# model/data params
base_model: str = "meta-llama/Llama-2-7b-hf", # the only required argument, default to be "meta-llama/Llama-2-7b-hf"
saved_low_bit_model: str = None, # optional, the path to the saved model with bigdl-llm low-bit optimization
data_path: str = "yahma/alpaca-cleaned",
output_dir: str = "./bigdl-qlora-alpaca",
# training hyperparams
bf16: bool = True, # default to bf16
batch_size: int = 128,
micro_batch_size: int = 2, # default to be 2, limited by GPU memory
num_epochs: int = 3,
learning_rate: float = 3e-5, # default to be 3e-5 to avoid divergence
cutoff_len: int = 256,
val_set_size: int = 2000,
# lora hyperparams
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules: List[str] = [
"q_proj",
"v_proj",
"k_proj",
"o_proj",
"up_proj",
"down_proj",
"gate_proj"
], # according to the QLoRA paper (https://arxiv.org/pdf/2305.14314.pdf), it's suggested to fine tune all linear layers
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
add_eos_token: bool = False,
group_by_length: bool = False, # faster, but produces an odd training loss curve
# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
gradient_checkpointing: bool = False,
deepspeed: str = None,
training_mode: str = "qlora",
):
invalidInputError(training_mode == "qlora",
f"This example is for qlora training mode, but got training_mode={training_mode}.")
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
print(
f"Training Alpaca-LoRA model with params:\n"
f"base_model: {base_model}\n"
f"data_path: {data_path}\n"
f"output_dir: {output_dir}\n"
f"batch_size: {batch_size}\n"
f"micro_batch_size: {micro_batch_size}\n"
f"num_epochs: {num_epochs}\n"
f"learning_rate: {learning_rate}\n"
f"cutoff_len: {cutoff_len}\n"
f"val_set_size: {val_set_size}\n"
f"lora_r: {lora_r}\n"
f"lora_alpha: {lora_alpha}\n"
f"lora_dropout: {lora_dropout}\n"
f"lora_target_modules: {lora_target_modules}\n"
f"train_on_inputs: {train_on_inputs}\n"
f"add_eos_token: {add_eos_token}\n"
f"group_by_length: {group_by_length}\n"
f"wandb_project: {wandb_project}\n"
f"wandb_run_name: {wandb_run_name}\n"
f"wandb_watch: {wandb_watch}\n"
f"wandb_log_model: {wandb_log_model}\n"
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
f"prompt template: {prompt_template_name}\n"
f"training_mode: {training_mode}\n"
)
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
gradient_accumulation_steps = batch_size // micro_batch_size
prompter = Prompter(prompt_template_name)
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
gradient_accumulation_steps = gradient_accumulation_steps // world_size
# Check if parameter passed or if set within environ
use_wandb = wandb_check(wandb_project, wandb_watch, wandb_log_model)
if saved_low_bit_model is not None:
# Load the low-bit optimized model if a saved path is provided
model = AutoModelForCausalLM.load_low_bit(
saved_low_bit_model,
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
)
else:
# According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
# use bnb_config for qlora/qalora/relora, which use 4-bit quantization for the base model
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(base_model,
quantization_config=bnb_config, )
# The following is also supported:
# load the base model from a directory or the HF Hub in 4-bit format
# model = AutoModelForCausalLM.from_pretrained(
# base_model,
# load_in_low_bit="nf4",
# optimize_model=False,
# torch_dtype=torch.bfloat16,
# # device_map=device_map,
# modules_to_not_convert=["lm_head"],
# )
print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
tokenizer = LlamaTokenizer.from_pretrained(base_model)
print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")
tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference
print(model)
# Prepare a BigDL-LLM compatible Peft model
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=gradient_checkpointing)
config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
training_mode=training_mode,
)
print(f"Lora Config: {config}")
model = get_peft_model(model, config)
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
data = load_dataset("json", data_files=data_path)
else:
data = load_dataset(data_path)
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
train_data, val_data = get_train_val_data(data, tokenizer, prompter, train_on_inputs,
add_eos_token, cutoff_len, val_set_size, seed=42)
# Unused
# if not ddp and torch.cuda.device_count() > 1:
# # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
# model.is_parallelizable = True
# model.model_parallel = True
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
# warmup_ratio=0.03,
# warmup_steps=100,
max_grad_norm=0.3,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
lr_scheduler_type="cosine",
bf16=True, # ensures more stable training
logging_steps=1,
optim="adamw_torch",
evaluation_strategy="steps" if val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=100 if val_set_size > 0 else None,
save_steps=100,
output_dir=output_dir,
save_total_limit=100,
load_best_model_at_end=True if val_set_size > 0 else False,
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,
report_to="wandb" if use_wandb else None,
run_name=wandb_run_name if use_wandb else None,
gradient_checkpointing=gradient_checkpointing,
ddp_backend="ccl",
deepspeed=deepspeed,
save_safetensors=False,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
model.config.use_cache = False
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
model.save_pretrained(output_dir)
print(
"\n If there's a warning about missing keys above, please disregard :)"
)
if __name__ == "__main__":
fire.Fire(train)

View file

@ -0,0 +1,16 @@
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
},
"bp16": {
"enabled": true
},
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto"
}
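This ZeRO-2 config is what the Llama-2-70B launch scripts below pass via `--deepspeed`, and the training script simply forwards that path to `transformers.TrainingArguments`. A minimal sketch of that wiring (the batch-size values here are illustrative, and the DeepSpeed/accelerate packages from the launch environment are assumed to be installed):

```python
import transformers

# Sketch only: TrainingArguments accepts either a dict or a path to a DeepSpeed
# JSON file, so the finetuning script can forward the --deepspeed argument as-is.
training_args = transformers.TrainingArguments(
    output_dir="./bigdl-qlora-alpaca",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,
    bf16=True,
    deepspeed="./deepspeed_zero2.json",  # ZeRO stage 2 + CPU optimizer offload, as defined above
)
```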

View file

@ -0,0 +1,44 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import torch
from transformers import LlamaTokenizer # noqa: F402
import argparse
current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..', '..')
import sys
sys.path.append(common_util_path)
from common.utils import merge_adapter
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Merge the adapter into the original model for Llama2 model')
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-hf",
help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--adapter_path', type=str,)
parser.add_argument('--output_path', type=str,)
args = parser.parse_args()
base_model = model_path = args.repo_id_or_model_path
adapter_path = args.adapter_path
output_path = args.output_path
tokenizer = LlamaTokenizer.from_pretrained(base_model)
merge_adapter(base_model, tokenizer, adapter_path, output_path)
print(f'Finished merging the adapter into the original model; you can find the merged model in {output_path}.')

View file

@ -0,0 +1,28 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=56
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
mpirun -n 2 \
python -u ./alpaca_qlora_finetuning.py \
--base_model "meta-llama/Llama-2-13b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--micro_batch_size 8 \
--batch_size 128 > training.log

View file

@ -0,0 +1,23 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You could also set `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file
python ./alpaca_qlora_finetuning.py \
--base_model "meta-llama/Llama-2-13b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--micro_batch_size 8 \
--batch_size 128

View file

@ -0,0 +1,28 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=56
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
mpirun -n 8 \
python -u ./alpaca_qlora_finetuning.py \
--base_model "meta-llama/Llama-2-13b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--micro_batch_size 8 \
--batch_size 128 > training.log

View file

@ -0,0 +1,36 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# save Llama-2-70b-hf model with bigdl-llm low-bit optimization first
python save_low_bit_70b_model.py --output_path "./llama-2-70b-hf-nf4"
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=56
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export CCL_ZE_IPC_EXCHANGE=sockets
mpirun -n 2 \
python -u ./alpaca_qlora_finetuning.py \
--base_model "meta-llama/Llama-2-70b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--gradient_checkpointing True \
--micro_batch_size 8 \
--batch_size 128 \
--deepspeed ./deepspeed_zero2.json \
--saved_low_bit_model ./llama-2-70b-hf-nf4 > training.log

View file

@ -0,0 +1,36 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# save Llama-2-70b-hf model with bigdl-llm low-bit optimization first
python save_low_bit_70b_model.py --output_path "./llama-2-70b-hf-nf4"
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=56
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export CCL_ZE_IPC_EXCHANGE=sockets
mpirun -n 8 \
python -u ./alpaca_qlora_finetuning.py \
--base_model "meta-llama/Llama-2-70b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--gradient_checkpointing True \
--micro_batch_size 8 \
--batch_size 128 \
--deepspeed ./deepspeed_zero2.json \
--saved_low_bit_model ./llama-2-70b-hf-nf4 > training.log

View file

@ -15,7 +15,7 @@
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=6 # adjust this to 1/4 of total physical cores
export OMP_NUM_THREADS=6
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi

View file

@ -15,7 +15,7 @@
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=12 # adjust this to 1/4 of total physical cores
export OMP_NUM_THREADS=12
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi

View file

@ -15,7 +15,7 @@
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=28 # adjust this to 1/4 of total physical cores
export OMP_NUM_THREADS=28
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi

View file

@ -15,7 +15,7 @@
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=28 # adjust this to 1/4 of total physical cores
export OMP_NUM_THREADS=56
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi

View file

@ -15,7 +15,7 @@
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=28 # adjust this to 1/4 of total physical cores
export OMP_NUM_THREADS=56
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi

View file

@ -0,0 +1,45 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import LlamaTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM
import torch
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Save model with bigdl-llm low-bit optimization')
parser.add_argument('--base_model', type=str, default="meta-llama/Llama-2-70b-hf",
help='The huggingface repo id for the Llama2-70B model to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--output_path', type=str, default="./llama-2-70b-hf-nf4",
help='The path to the saved model.')
args = parser.parse_args()
base_model = args.base_model
output_path = args.output_path
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_low_bit="nf4",
# load_in_4bit=True,
optimize_model=False,
torch_dtype=torch.bfloat16,
# device_map=device_map,
modules_to_not_convert=["lm_head"],
)
model.save_low_bit(output_path)
print(f'Model with bigdl-llm low-bit optimization is saved to {output_path}.')

View file

@ -1,9 +1,10 @@
# Finetuning LLAMA Using Q-Lora (experimental support)
# Simple Example of QLoRA Finetuning with BigDL-LLM
This example demonstrates how to finetune a llama2-7b model use Big-LLM 4bit optimizations using [Intel GPUs](../README.md).
This simple example demonstrates how to finetune a llama2-7b model using BigDL-LLM 4-bit optimizations on [Intel GPUs](../../../README.md).
Note that this example is only for illustrating related usage and does not guarantee training convergence.
## 0. Requirements
To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../../README.md#requirements) for more information.
## Example: Finetune llama2-7b using qlora

View file

@ -0,0 +1,44 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import torch
from transformers import LlamaTokenizer # noqa: F402
import argparse
current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..', '..')
import sys
sys.path.append(common_util_path)
from common.utils import merge_adapter
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Merge the adapter into the original model for Llama2 model')
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-hf",
help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--adapter_path', type=str,)
parser.add_argument('--output_path', type=str,)
args = parser.parse_args()
base_model = model_path = args.repo_id_or_model_path
adapter_path = args.adapter_path
output_path = args.output_path
tokenizer = LlamaTokenizer.from_pretrained(base_model)
merge_adapter(base_model, tokenizer, adapter_path, output_path)
print(f'Finished merging the adapter into the original model; you can find the merged model in {output_path}.')

View file

@ -28,7 +28,7 @@ import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
parser = argparse.ArgumentParser(description='Simple example of QLoRA finetuning of a llama2 model using bigdl-llm')
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-hf",
help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
', or the path to the huggingface checkpoint folder')

View file

@ -0,0 +1,9 @@
# Running LLM Finetuning using BigDL-LLM on Intel GPU
This folder contains examples of running different training modes with BigDL-LLM on Intel GPU:
- [LoRA](LoRA): examples of running LoRA finetuning
- [QLoRA](QLoRA): examples of running QLoRA finetuning
- [QA-LoRA](QA-LoRA): examples of running QA-LoRA finetuning
- [ReLoRA](ReLora): examples of running ReLoRA finetuning
- [common](common): common templates and utility classes in finetuning examples

View file

@ -0,0 +1,90 @@
# ReLoRA Finetuning with BigDL-LLM
This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM (using the [ReLoRA](https://arxiv.org/abs/2307.05695) algorithm) on [Intel GPU](../../README.md).
### 0. Requirements
To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../README.md#requirements) for more information.
### 1. Install
```bash
conda create -n llm python=3.9
conda activate llm
# the command below installs intel_extension_for_pytorch==2.1.10+xpu by default
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
pip install transformers==4.34.0 datasets
pip install fire peft==0.5.0
pip install oneccl_bind_pt==2.1.100 -f https://developer.intel.com/ipex-whl-stable-xpu # necessary to run distributed finetuning
pip install accelerate==0.23.0
pip install bitsandbytes scipy
```
### 2. Configure OneAPI environment variables
```bash
source /opt/intel/oneapi/setvars.sh
```
### 3. ReLoRA Finetune
Here, we provide example usages on different hardware. Please refer to the appropriate script based on your device:
##### Finetuning LLaMA2-7B on single Arc A770
```bash
bash relora_finetune_llama2_7b_arc_1_card.sh
```
##### Finetuning LLaMA2-7B on two Arc A770
```bash
bash relora_finetune_llama2_7b_arc_2_card.sh
```
##### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1550
```bash
bash relora_finetune_llama2_7b_pvc_1550_1_card.sh
```
##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1550
```bash
bash relora_finetune_llama2_7b_pvc_1550_4_card.sh
```
### 4. (Optional) Resume Training
**If you fail to complete the whole finetuning process, it is suggested to resume training from a previously saved checkpoint by setting `resume_from_checkpoint` to the local checkpoint folder, as follows:**
```bash
python ./alpaca_relora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--resume_from_checkpoint "./bigdl-qlora-alpaca/checkpoint-1100"
```
### 5. Sample Output
```log
{'loss': 1.9231, 'learning_rate': 2.9999945367033285e-05, 'epoch': 0.0}
{'loss': 1.8622, 'learning_rate': 2.9999781468531096e-05, 'epoch': 0.01}
{'loss': 1.9043, 'learning_rate': 2.9999508305687345e-05, 'epoch': 0.01}
{'loss': 1.8967, 'learning_rate': 2.999912588049185e-05, 'epoch': 0.01}
{'loss': 1.9658, 'learning_rate': 2.9998634195730358e-05, 'epoch': 0.01}
{'loss': 1.8386, 'learning_rate': 2.9998033254984483e-05, 'epoch': 0.02}
{'loss': 1.809, 'learning_rate': 2.999732306263172e-05, 'epoch': 0.02}
{'loss': 1.8552, 'learning_rate': 2.9996503623845395e-05, 'epoch': 0.02}
1%|█ | 8/1164 [xx:xx<xx:xx:xx, xx s/it]
```
### 6. Merge the adapter into the original model
```
python ./export_merged_model.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --adapter_path ./outputs/checkpoint-200 --output_path ./outputs/checkpoint-200-merged
```
Then you can use `./outputs/checkpoint-200-merged` as a normal Hugging Face Transformers model for inference, for example as sketched below.
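A minimal inference sketch (the prompt and generation settings are illustrative; 4-bit loading and XPU placement assume the install steps above):

```python
import torch
from transformers import LlamaTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

merged_path = "./outputs/checkpoint-200-merged"  # path produced by the merge step above
tokenizer = LlamaTokenizer.from_pretrained(merged_path)
model = AutoModelForCausalLM.from_pretrained(merged_path, load_in_4bit=True)
model = model.to("xpu")  # drop this line to run on CPU instead

inputs = tokenizer("What is AI?", return_tensors="pt").to(model.device)
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```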
### 7. Troubleshooting
- If you fail to finetune on multiple cards because of the following error message:
```bash
RuntimeError: oneCCL: comm_selector.cpp:57 create_comm_impl: EXCEPTION: ze_data was not initialized
```
Please try `sudo apt install level-zero-dev` to fix it.

View file

@ -44,29 +44,20 @@ from peft import (
get_peft_model_state_dict,
set_peft_model_state_dict,
)
from utils.prompter import Prompter
current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..')
import sys
sys.path.append(common_util_path)
from common.utils import Prompter, get_int_from_env, wandb_check, get_train_val_data
from transformers import BitsAndBytesConfig
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.transformers.relora import ReLoRATrainer
# import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
LoraConfig
from bigdl.llm.utils.common import invalidInputError
def get_int_from_env(env_keys, default):
"""Returns the first positive env value found in the `env_keys` list or the default."""
for e in env_keys:
val = int(os.environ.get(e, -1))
if val >= 0:
return val
return int(default)
def _get_trainer_cls(training_mode):
if training_mode == "relora":
from bigdl.llm.transformers.relora import ReLoRATrainer
return ReLoRATrainer
return transformers.Trainer
local_rank = get_int_from_env(["LOCAL_RANK","MPI_LOCALRANKID"], "0")
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
@ -102,7 +93,7 @@ def train(
"up_proj",
"down_proj",
"gate_proj"
], # according to the QLoRA paper (https://arxiv.org/pdf/2305.14314.pdf), it's suggested to fine tune all linear layers
],
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
add_eos_token: bool = False,
@ -116,7 +107,7 @@ def train(
prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
gradient_checkpointing: bool = False,
deepspeed: str = None,
training_mode: str = "qlora",
training_mode: str = "relora",
# relora params; relora_steps should be > 0 if the training mode is `relora`,
# Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695,
# minus the initial full fine-tune.
@ -124,8 +115,8 @@ def train(
relora_warmup_steps: int = 10, # Number of per-restart warmup steps
relora_cpu_offload: bool = True, # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
):
invalidInputError(training_mode in ["qlora", "qalora", "lora", "relora"],
"Only qlora / qalora / lora / relora are supported for training_mode now.")
invalidInputError(training_mode == "relora",
f"This example is for relora training mode, but got training_mode={training_mode}.")
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
print(
f"Training Alpaca-LoRA model with params:\n"
@ -174,16 +165,7 @@ def train(
gradient_accumulation_steps = gradient_accumulation_steps // world_size
# Check if parameter passed or if set within environ
use_wandb = len(wandb_project) > 0 or (
"WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
)
# Only overwrite environ if wandb param passed
if len(wandb_project) > 0:
os.environ["WANDB_PROJECT"] = wandb_project
if len(wandb_watch) > 0:
os.environ["WANDB_WATCH"] = wandb_watch
if len(wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = wandb_log_model
use_wandb = wandb_check(wandb_project, wandb_watch, wandb_log_model)
if saved_low_bit_model is not None:
# Load the low-bit optimized model if a saved path is provided
@ -194,42 +176,20 @@ def train(
modules_to_not_convert=["lm_head"],
)
else:
# According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
# Default 4-bit format for qa-lora is sym_int4
if training_mode == "lora":
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_low_bit="bf16",
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
)
else:
# use bnb_config for qlora/qalora/relora, which use 4bit for base model
if training_mode == "qalora":
low_bit_format = "int4"
else:
low_bit_format = "nf4"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type=low_bit_format,
bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(base_model,
quantization_config=bnb_config, )
# use bnb_config for qlora/qalora/relora, which use 4bit for base model
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(base_model,
quantization_config=bnb_config, )
# below is also supported
# Load the base model from a directory or the HF Hub to 4-bit format
# if training_mode == "qalora":
# low_bit_format = "sym_int4"
# elif training_mode == "lora":
# low_bit_format = "bf16"
# else:
# low_bit_format = "nf4"
# model = AutoModelForCausalLM.from_pretrained(
# base_model,
# load_in_low_bit=low_bit_format,
# load_in_low_bit="nf4",
# optimize_model=False,
# torch_dtype=torch.bfloat16,
# # device_map=device_map,
@ -249,54 +209,6 @@ def train(
print(model)
def tokenize(prompt, add_eos_token=True):
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
result = tokenizer(
prompt,
truncation=True,
max_length=cutoff_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < cutoff_len
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
def generate_and_tokenize_prompt(data_point):
full_prompt = prompter.generate_prompt(
data_point["instruction"],
data_point["input"],
data_point["output"],
)
tokenized_full_prompt = tokenize(full_prompt)
if not train_on_inputs:
user_prompt = prompter.generate_prompt(
data_point["instruction"], data_point["input"]
)
tokenized_user_prompt = tokenize(
user_prompt, add_eos_token=add_eos_token
)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
if add_eos_token:
user_prompt_len -= 1
tokenized_full_prompt["labels"] = [
-100
] * user_prompt_len + tokenized_full_prompt["labels"][
user_prompt_len:
] # could be sped up, probably
return tokenized_full_prompt
# Prepare a BigDL-LLM compatible Peft model
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=gradient_checkpointing)
@ -319,19 +231,8 @@ def train(
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
if val_set_size > 0:
train_val = data["train"].train_test_split(
test_size=val_set_size, shuffle=True, seed=42
)
train_data = (
train_val["train"].shuffle().map(generate_and_tokenize_prompt)
)
val_data = (
train_val["test"].shuffle().map(generate_and_tokenize_prompt)
)
else:
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = None
train_data, val_data = get_train_val_data(data, tokenizer, prompter, train_on_inputs,
add_eos_token, cutoff_len, val_set_size, seed=42)
# Unused
# if not ddp and torch.cuda.device_count() > 1:
@ -339,7 +240,6 @@ def train(
# model.is_parallelizable = True
# model.model_parallel = True
trainer_cls = _get_trainer_cls(training_mode=training_mode)
extra_args = {}
if training_mode == "relora":
extra_args["base_model"] = base_model
@ -348,7 +248,7 @@ def train(
extra_args["relora_cpu_offload"] = relora_cpu_offload
extra_args["resume_from_checkpoint"] = resume_from_checkpoint
trainer = trainer_cls(
trainer = ReLoRATrainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
@ -361,7 +261,7 @@ def train(
max_grad_norm=0.3,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
lr_scheduler_type="constant" if training_mode=="qalora" else "cosine",
lr_scheduler_type="cosine",
bf16=True, # ensures more stable training
logging_steps=1,
optim="adamw_torch",
@ -370,7 +270,7 @@ def train(
eval_steps=100 if val_set_size > 0 else None,
save_steps=100,
output_dir=output_dir,
save_total_limit=100 if training_mode != "relora" else 4, # relora will save the whole model, here we use 4 to save the disk space.
save_total_limit=4, # relora saves the whole model, so we use 4 to save disk space.
load_best_model_at_end=True if val_set_size > 0 else False,
ddp_find_unused_parameters=False if ddp else None,
group_by_length=group_by_length,

View file

@ -0,0 +1,44 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import torch
from transformers import LlamaTokenizer # noqa: F402
import argparse
current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..')
import sys
sys.path.append(common_util_path)
from common.utils import merge_adapter
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Merge the adapter into the original model for Llama2 model')
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-hf",
help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--adapter_path', type=str,)
parser.add_argument('--output_path', type=str,)
args = parser.parse_args()
base_model = model_path = args.repo_id_or_model_path
adapter_path = args.adapter_path
output_path = args.output_path
tokenizer = LlamaTokenizer.from_pretrained(base_model)
merge_adapter(base_model, tokenizer, adapter_path, output_path)
print(f'Finished merging the adapter into the original model; you can find the merged model in {output_path}.')

View file

@ -15,10 +15,9 @@
#
# You could also set `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file
python ./alpaca_qlora_finetuning.py \
python ./alpaca_relora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-relora-alpaca" \
--relora_steps 300 \
--relora_warmup_steps 10 \
--training_mode "relora"
--relora_warmup_steps 10

View file

@ -15,15 +15,14 @@
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=6 # adjust this to 1/4 of total physical cores
export OMP_NUM_THREADS=6
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
mpirun -n 2 \
python -u ./alpaca_qlora_finetuning.py \
python -u ./alpaca_relora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-relora-alpaca" \
--relora_steps 300 \
--relora_warmup_steps 10 \
--training_mode "relora" > training.log
--relora_warmup_steps 10 > training.log

View file

@ -15,17 +15,16 @@
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=28 # adjust this to 1/4 of total physical cores
export OMP_NUM_THREADS=56
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
mpirun -n 2 \
python -u ./alpaca_qlora_finetuning.py \
python -u ./alpaca_relora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-relora-alpaca" \
--micro_batch_size 8 \
--relora_steps 300 \
--relora_warmup_steps 10 \
--batch_size 128 \
--training_mode "relora" > relora_training.log
--batch_size 128 > relora_training.log

View file

@ -15,17 +15,16 @@
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=28 # adjust this to 1/4 of total physical cores
export OMP_NUM_THREADS=56
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
mpirun -n 8 \
python -u ./alpaca_qlora_finetuning.py \
python -u ./alpaca_relora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-relora-alpaca" \
--micro_batch_size 8 \
--relora_steps 300 \
--relora_warmup_steps 10 \
--batch_size 128 \
--training_mode "relora" > relora_training.log
--batch_size 128 > relora_training.log

View file

@ -0,0 +1,18 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .prompter import Prompter
from .util import *

View file

@ -45,7 +45,9 @@ class Prompter(object):
if not template_name:
# Enforce the default here, so the constructor can be called with '' and will not break.
template_name = "alpaca"
file_name = osp.join("templates", f"{template_name}.json")
current_dir = osp.dirname(osp.realpath(__file__))
common_util_path = osp.join(current_dir, '..')
file_name = osp.join(common_util_path, "templates", f"{template_name}.json")
if not osp.exists(file_name):
invalidInputError(False, f"Can't read {file_name}")
with open(file_name) as fp:

View file

@ -0,0 +1,213 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/tloen/alpaca-lora/blob/main/finetune.py
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from https://github.com/tloen/alpaca-lora/blob/main/export_hf_checkpoint.py
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import transformers
def get_int_from_env(env_keys, default):
"""Returns the first positive env value found in the `env_keys` list or the default."""
for e in env_keys:
val = int(os.environ.get(e, -1))
if val >= 0:
return val
return int(default)
def wandb_check(wandb_project, wandb_watch, wandb_log_model):
"""Check if wandb related parameter passed or if set within environ"""
use_wandb = len(wandb_project) > 0 or (
"WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
)
# Only overwrite environ if wandb param passed
if len(wandb_project) > 0:
os.environ["WANDB_PROJECT"] = wandb_project
if len(wandb_watch) > 0:
os.environ["WANDB_WATCH"] = wandb_watch
if len(wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = wandb_log_model
return use_wandb
def get_train_val_data(data, tokenizer, prompter, train_on_inputs,
add_eos_token, cutoff_len, val_set_size, seed=42):
"""Data processing to get train data and val data"""
def tokenize(prompt, add_eos_token=True):
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
result = tokenizer(
prompt,
truncation=True,
max_length=cutoff_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < cutoff_len
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
def generate_and_tokenize_prompt(data_point):
full_prompt = prompter.generate_prompt(
data_point["instruction"],
data_point["input"],
data_point["output"],
)
tokenized_full_prompt = tokenize(full_prompt)
if not train_on_inputs:
user_prompt = prompter.generate_prompt(
data_point["instruction"], data_point["input"]
)
tokenized_user_prompt = tokenize(
user_prompt, add_eos_token=add_eos_token
)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
if add_eos_token:
user_prompt_len -= 1
tokenized_full_prompt["labels"] = [
-100
] * user_prompt_len + tokenized_full_prompt["labels"][
user_prompt_len:
] # could be sped up, probably
return tokenized_full_prompt
if val_set_size > 0:
train_val = data["train"].train_test_split(
test_size=val_set_size, shuffle=True, seed=seed
)
train_data = (
train_val["train"].shuffle().map(generate_and_tokenize_prompt)
)
val_data = (
train_val["test"].shuffle().map(generate_and_tokenize_prompt)
)
else:
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = None
return train_data, val_data
def merge_adapter(base_model, tokenizer, adapter_path, output_path):
"""Merge the adapter into the original model and save"""
import torch
from bigdl.llm.transformers.qlora import PeftModel, LoraConfig
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.transformers.low_bit_linear import get_block_size
import tempfile
import shutil
lora_config = LoraConfig.from_json_file(os.path.join(adapter_path, "adapter_config.json"))
training_mode = lora_config.get("training_mode", "qlora")
qa_lora = training_mode == "qalora"
temp_dir = None
if qa_lora:
# Convert the qa-lora adapter to the correct shapes
# The default 4-bit format for qa_lora is sym_int4
block_size = get_block_size("sym_int4")
temp_dir = tempfile.TemporaryDirectory()
tmpdirname = os.path.join(temp_dir.name, "adapter")
try:
shutil.copytree(adapter_path, tmpdirname)
except Exception as e:
print(f"Failed to copy adapter dir, error: {e}")
mid_lora_path = os.path.join(tmpdirname, "adapter_model.bin")
adapter_path = os.path.join(adapter_path, "adapter_model.bin")
lora = torch.load(adapter_path, map_location='cpu')
# Get lora_a names
tmp_keys = [key for key in lora.keys() if 'lora_A' in key]
for tmp_key in tmp_keys:
lora_a = lora[tmp_key] / block_size
lora[tmp_key] = torch.repeat_interleave(lora_a, block_size, dim=1)
torch.save(lora, mid_lora_path)
adapter_path = tmpdirname
try:
base_model = AutoModelForCausalLM.from_pretrained(
base_model,
# load_in_low_bit="nf4", # should load the original model
torch_dtype=torch.float16,
device_map={"": "cpu"},
)
lora_model = PeftModel.from_pretrained(
base_model,
adapter_path,
device_map={"": "cpu"},
torch_dtype=torch.float16,
)
# merge weights - new merging method from peft
lora_model = lora_model.merge_and_unload()
lora_model.train(False)
lora_model_sd = lora_model.state_dict()
deloreanized_sd = {
k.replace("base_model.model.", ""): v
for k, v in lora_model_sd.items()
if "lora" not in k
}
base_model.save_pretrained(output_path, state_dict=deloreanized_sd)
tokenizer.save_pretrained(output_path)
except Exception as e:
print(f"Failed to merge the adapter, error: {e}.")
finally:
if qa_lora and temp_dir:
temp_dir.cleanup()

View file

@ -1,119 +0,0 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is adapted from https://github.com/tloen/alpaca-lora/blob/main/export_hf_checkpoint.py
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from transformers import LlamaTokenizer # noqa: F402
from bigdl.llm.transformers.qlora import PeftModel, LoraConfig
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.transformers.low_bit_linear import get_block_size
import argparse
import tempfile
import shutil
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-hf",
help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--adapter_path', type=str,)
parser.add_argument('--output_path', type=str,)
args = parser.parse_args()
base_model = model_path = args.repo_id_or_model_path
adapter_path = args.adapter_path
tokenizer = LlamaTokenizer.from_pretrained(base_model)
lora_config = LoraConfig.from_json_file(os.path.join(adapter_path, "adapter_config.json"))
training_mode = lora_config.get("training_mode", "qlora")
qa_lora = training_mode == "qalora"
temp_dir = None
if qa_lora:
# Convert the qa-lora adapter to the correct shapes
# The default 4-bit format for qa_lora is sym_int4
block_size = get_block_size("sym_int4")
temp_dir = tempfile.TemporaryDirectory()
tmpdirname = os.path.join(temp_dir.name, "adapter")
try:
shutil.copytree(adapter_path, tmpdirname)
except Exception as e:
print(f"Failed to copy adapter dir, error: {e}")
mid_lora_path = os.path.join(tmpdirname, "adapter_model.bin")
adapter_path = os.path.join(adapter_path, "adapter_model.bin")
lora = torch.load(adapter_path, map_location='cpu')
# Get lora_a names
tmp_keys = [key for key in lora.keys() if 'lora_A' in key]
for tmp_key in tmp_keys:
lora_a = lora[tmp_key] / block_size
lora[tmp_key] = torch.repeat_interleave(lora_a, block_size, dim=1)
torch.save(lora, mid_lora_path)
adapter_path = tmpdirname
try:
base_model = AutoModelForCausalLM.from_pretrained(
base_model,
# load_in_low_bit="nf4", # should load the original model
torch_dtype=torch.float16,
device_map={"": "cpu"},
)
lora_model = PeftModel.from_pretrained(
base_model,
adapter_path,
device_map={"": "cpu"},
torch_dtype=torch.float16,
)
# merge weights - new merging method from peft
lora_model = lora_model.merge_and_unload()
lora_model.train(False)
lora_model_sd = lora_model.state_dict()
deloreanized_sd = {
k.replace("base_model.model.", ""): v
for k, v in lora_model_sd.items()
if "lora" not in k
}
base_model.save_pretrained(args.output_path, state_dict=deloreanized_sd)
tokenizer.save_pretrained(args.output_path)
except Exception as e:
print(f"Failed to merge the adapter, error: {e}.")
finally:
if qa_lora and temp_dir:
temp_dir.cleanup()

View file

@ -3,7 +3,7 @@
This folder contains examples of running BigDL-LLM on Intel GPU:
- [HF-Transformers-AutoModels](HF-Transformers-AutoModels): running any ***Hugging Face Transformers*** model on BigDL-LLM (using the standard AutoModel APIs)
- [QLoRA-FineTuning](QLoRA-FineTuning): running ***QLoRA finetuning*** using BigDL-LLM on Intel GPUs
- [LLM-Finetuning](LLM-Finetuning): running ***finetuning*** (such as LoRA, QLoRA, QA-LoRA, etc.) using BigDL-LLM on Intel GPUs
- [vLLM-Serving](vLLM-Serving): running ***vLLM*** serving framework on Intel GPUs (with BigDL-LLM low-bit optimized models)
- [Deepspeed-AutoTP](Deepspeed-AutoTP): running distributed inference using ***DeepSpeed AutoTP*** (with BigDL-LLM low-bit optimized models) on Intel GPUs
- [PyTorch-Models](PyTorch-Models): running any PyTorch model on BigDL-LLM (with "one-line code change")

View file

@ -8,13 +8,13 @@ echo "# Start testing qlora fine-tuning"
start=$(date "+%s")
sed -i 's/max_steps=200/max_steps=2/; s/save_steps=100/save_steps=2/; s/logging_steps=20/logging_steps=1/' \
${ANALYTICS_ZOO_ROOT}/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
${ANALYTICS_ZOO_ROOT}/python/llm/example/GPU/LLM-Finetuning/QLoRA/simple-example/qlora_finetuning.py
python ${ANALYTICS_ZOO_ROOT}/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py \
python ${ANALYTICS_ZOO_ROOT}/python/llm/example/GPU/LLM-Finetuning/QLoRA/simple-example/qlora_finetuning.py \
--repo-id-or-model-path ${LLAMA2_7B_ORIGIN_PATH} \
--dataset ${ABIRATE_ENGLISH_QUOTES_PATH}
python ${ANALYTICS_ZOO_ROOT}/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py \
python ${ANALYTICS_ZOO_ROOT}/python/llm/example/GPU/LLM-Finetuning/QLoRA/simple-example/export_merged_model.py \
--repo-id-or-model-path ${LLAMA2_7B_ORIGIN_PATH} \
--adapter_path ${PWD}/outputs/checkpoint-2 \
--output_path ${PWD}/outputs/checkpoint-2-merged