LLM : Add qlora cpu finetune docker image (#9271)
* init qlora cpu docker image * update * remove ipex and update * update * update readme * update example and readme
This commit is contained in:
parent
d109275333
commit
0f78ebe35e
4 changed files with 273 additions and 0 deletions
39
docker/llm/finetune/qlora/cpu/docker/Dockerfile
Normal file
39
docker/llm/finetune/qlora/cpu/docker/Dockerfile
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
FROM intel/oneapi-basekit:2023.2.1-devel-ubuntu22.04
|
||||
ARG http_proxy
|
||||
ARG https_proxy
|
||||
ENV TZ=Asia/Shanghai
|
||||
ARG PIP_NO_CACHE_DIR=false
|
||||
ENV TRANSFORMERS_COMMIT_ID=95fe0f5
|
||||
|
||||
# retrive oneapi repo public key
|
||||
RUN curl -fsSL https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2023.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \
|
||||
echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " > /etc/apt/sources.list.d/oneAPI.list
|
||||
|
||||
# update dependencies
|
||||
RUN apt-get update && \
|
||||
# install basic dependencies
|
||||
apt-get install -y curl wget git gnupg gpg-agent software-properties-common libunwind8-dev vim less && \
|
||||
# install python 3.9
|
||||
ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && \
|
||||
env DEBIAN_FRONTEND=noninteractive apt-get update && \
|
||||
add-apt-repository ppa:deadsnakes/ppa -y && \
|
||||
apt-get install -y python3.9 && \
|
||||
rm /usr/bin/python3 && \
|
||||
ln -s /usr/bin/python3.9 /usr/bin/python3 && \
|
||||
ln -s /usr/bin/python3 /usr/bin/python && \
|
||||
apt-get install -y python3-pip python3.9-dev python3-wheel python3.9-distutils && \
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
|
||||
# install torch and oneccl to reduce bigdl-llm size
|
||||
RUN pip3 install --upgrade pip && \
|
||||
export PIP_DEFAULT_TIMEOUT=100 && \
|
||||
pip install --upgrade torch==2.0.1 --index-url https://download.pytorch.org/whl/cpu && \
|
||||
pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable && \
|
||||
# install CPU bigdl-llm
|
||||
pip install --pre --upgrade bigdl-llm[all] -i https://pypi.tuna.tsinghua.edu.cn/simple/ && \
|
||||
# install huggingface dependencies
|
||||
pip install transformers==4.34.0 && \
|
||||
pip install peft==0.5.0 datasets
|
||||
|
||||
ADD ./qlora_finetuning_cpu.py /qlora_finetuning_cpu.py
|
||||
ADD ./start-qlora-finetuning-on-cpu.sh /start-qlora-finetuning-on-cpu.sh
|
||||
129
docker/llm/finetune/qlora/cpu/docker/README.md
Normal file
129
docker/llm/finetune/qlora/cpu/docker/README.md
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
## Fine-tune LLM with BigDL LLM Container
|
||||
|
||||
The following shows how to fine-tune LLM with Quantization (QLoRA built on BigDL-LLM 4bit optimizations) in a docker environment, which is accelerated by Intel CPU.
|
||||
|
||||
### 1. Prepare Docker Image
|
||||
|
||||
You can download directly from Dockerhub like:
|
||||
|
||||
```bash
|
||||
docker pull intelanalytics/bigdl-llm-finetune-qlora-cpu:2.4.0-SNAPSHOT
|
||||
```
|
||||
|
||||
Or build the image from source:
|
||||
|
||||
```bash
|
||||
export HTTP_PROXY=your_http_proxy
|
||||
export HTTPS_PROXY=your_https_proxy
|
||||
|
||||
docker build \
|
||||
--build-arg http_proxy=${HTTP_PROXY} \
|
||||
--build-arg https_proxy=${HTTPS_PROXY} \
|
||||
-t intelanalytics/bigdl-llm-finetune-qlora-cpu:2.4.0-SNAPSHOT \
|
||||
-f ./Dockerfile .
|
||||
```
|
||||
|
||||
### 2. Prepare Base Model, Data and Container
|
||||
|
||||
Here, we try to fine-tune a [Llama2-7b](https://huggingface.co/meta-llama/Llama-2-7b) with [English Quotes](https://huggingface.co/datasets/Abirate/english_quotes) dataset, and please download them and start a docker container with files mounted like below:
|
||||
|
||||
```bash
|
||||
export BASE_MODE_PATH=your_downloaded_base_model_path
|
||||
export DATA_PATH=your_downloaded_data_path
|
||||
export HTTP_PROXY=your_http_proxy
|
||||
export HTTPS_PROXY=your_https_proxy
|
||||
|
||||
docker run -itd \
|
||||
--net=host \
|
||||
--name=bigdl-llm-fintune-qlora-cpu \
|
||||
-e http_proxy=${HTTP_PROXY} \
|
||||
-e https_proxy=${HTTPS_PROXY} \
|
||||
-v $BASE_MODE_PATH:/model \
|
||||
-v $DATA_PATH:/data/english_quotes \
|
||||
intelanalytics/bigdl-llm-finetune-qlora-cpu:2.4.0-SNAPSHOT
|
||||
```
|
||||
|
||||
The download and mount of base model and data to a docker container demonstrates a standard fine-tuning process. You can skip this step for a quick start, and in this way, the fine-tuning codes will automatically download the needed files:
|
||||
|
||||
```bash
|
||||
export HTTP_PROXY=your_http_proxy
|
||||
export HTTPS_PROXY=your_https_proxy
|
||||
|
||||
docker run -itd \
|
||||
--net=host \
|
||||
--name=bigdl-llm-fintune-qlora-cpu \
|
||||
-e http_proxy=${HTTP_PROXY} \
|
||||
-e https_proxy=${HTTPS_PROXY} \
|
||||
intelanalytics/bigdl-llm-finetune-qlora-cpu:2.4.0-SNAPSHOT
|
||||
```
|
||||
|
||||
However, we do recommend you to handle them manually, because the automatical download can be blocked by Internet access and Huggingface authentication etc. according to different environment, and the manual method allows you to fine-tune in a custom way (with different base model and dataset).
|
||||
|
||||
### 3. Start Fine-Tuning
|
||||
|
||||
Enter the running container:
|
||||
|
||||
```bash
|
||||
docker exec -it bigdl-llm-fintune-qlora-cpu bash
|
||||
```
|
||||
|
||||
Then, start QLoRA fine-tuning:
|
||||
If the machine memory is not enough, you can try to set `use_gradient_checkpointing=True`.
|
||||
|
||||
And remember to use `bigdl-llm-init` before you start finetuning, which can accelerate the job.
|
||||
```bash
|
||||
source bigdl-llm-init -t
|
||||
bash start-qlora-finetuning-on-cpu.sh
|
||||
```
|
||||
|
||||
After minutes, it is expected to get results like:
|
||||
|
||||
```bash
|
||||
{'loss': 2.256, 'learning_rate': 0.0002, 'epoch': 0.03}
|
||||
{'loss': 1.8869, 'learning_rate': 0.00017777777777777779, 'epoch': 0.06}
|
||||
{'loss': 1.5334, 'learning_rate': 0.00015555555555555556, 'epoch': 0.1}
|
||||
{'loss': 1.4975, 'learning_rate': 0.00013333333333333334, 'epoch': 0.13}
|
||||
{'loss': 1.3245, 'learning_rate': 0.00011111111111111112, 'epoch': 0.16}
|
||||
{'loss': 1.2622, 'learning_rate': 8.888888888888889e-05, 'epoch': 0.19}
|
||||
{'loss': 1.3944, 'learning_rate': 6.666666666666667e-05, 'epoch': 0.22}
|
||||
{'loss': 1.2481, 'learning_rate': 4.4444444444444447e-05, 'epoch': 0.26}
|
||||
{'loss': 1.3442, 'learning_rate': 2.2222222222222223e-05, 'epoch': 0.29}
|
||||
{'loss': 1.3256, 'learning_rate': 0.0, 'epoch': 0.32}
|
||||
{'train_runtime': xxx, 'train_samples_per_second': xxx, 'train_steps_per_second': xxx, 'train_loss': 1.5072882556915284, 'epoch': 0.32}
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████| 200/200 [xx:xx<xx:xx, xxxs/it]
|
||||
TrainOutput(global_step=200, training_loss=1.5072882556915284, metrics={'train_runtime': xxx, 'train_samples_per_second': xxx, 'train_steps_per_second': xxx, 'train_loss': 1.5072882556915284, 'epoch': 0.32})
|
||||
```
|
||||
|
||||
### 4. Merge the adapter into the original model
|
||||
Using the [export_merged_model.py](https://github.com/intel-analytics/BigDL/blob/main/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py) to merge.
|
||||
```
|
||||
python ./export_merged_model.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --adapter_path ./outputs/checkpoint-200 --output_path ./outputs/checkpoint-200-merged
|
||||
```
|
||||
|
||||
Then you can use `./outputs/checkpoint-200-merged` as a normal huggingface transformer model to do inference.
|
||||
|
||||
### 5. Use BigDL-LLM to verify the fine-tuning effect
|
||||
Train more steps and try input sentence like `['quote'] -> [?]` to verify. For example, using `“QLoRA fine-tuning using BigDL-LLM 4bit optimizations on Intel CPU is Efficient and convenient” ->: ` to inference.
|
||||
BigDL-LLM llama2 example [link](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/HF-Transformers-AutoModels/Model/llama2). Update the `LLAMA2_PROMPT_FORMAT = "{prompt}"`.
|
||||
```bash
|
||||
python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt "“QLoRA fine-tuning using BigDL-LLM 4bit optimizations on Intel CPU is Efficient and convenient” ->:" --n-predict 20
|
||||
```
|
||||
|
||||
#### Sample Output
|
||||
Base_model output
|
||||
```log
|
||||
Inference time: xxx s
|
||||
-------------------- Prompt --------------------
|
||||
“QLoRA fine-tuning using BigDL-LLM 4bit optimizations on Intel CPU is Efficient and convenient” ->:
|
||||
-------------------- Output --------------------
|
||||
“QLoRA fine-tuning using BigDL-LLM 4bit optimizations on Intel CPU is Efficient and convenient” ->: 💻 Fine-tuning a language model on a powerful device like an Intel CPU
|
||||
```
|
||||
Merged_model output
|
||||
```log
|
||||
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
|
||||
Inference time: xxx s
|
||||
-------------------- Prompt --------------------
|
||||
“QLoRA fine-tuning using BigDL-LLM 4bit optimizations on Intel CPU is Efficient and convenient” ->:
|
||||
-------------------- Output --------------------
|
||||
“QLoRA fine-tuning using BigDL-LLM 4bit optimizations on Intel CPU is Efficient and convenient” ->: ['bigdl'] ['deep-learning'] ['distributed-computing'] ['intel'] ['optimization'] ['training'] ['training-speed']
|
||||
```
|
||||
87
docker/llm/finetune/qlora/cpu/docker/qlora_finetuning_cpu.py
Normal file
87
docker/llm/finetune/qlora/cpu/docker/qlora_finetuning_cpu.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import torch
|
||||
import os
|
||||
|
||||
import transformers
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
from peft import LoraConfig
|
||||
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
from datasets import load_dataset
|
||||
import argparse
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
|
||||
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-hf",
|
||||
help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
|
||||
', or the path to the huggingface checkpoint folder')
|
||||
parser.add_argument('--dataset', type=str, default="Abirate/english_quotes")
|
||||
|
||||
args = parser.parse_args()
|
||||
model_path = args.repo_id_or_model_path
|
||||
dataset_path = args.dataset
|
||||
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
data = load_dataset(dataset_path)
|
||||
def merge(row):
|
||||
row['prediction'] = row['quote'] + ' ->: ' + str(row['tags'])
|
||||
return row
|
||||
data = data.map(lambda samples: tokenizer(samples["prediction"]), batched=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path,
|
||||
load_in_low_bit="sym_int4",
|
||||
optimize_model=False,
|
||||
torch_dtype=torch.float16,
|
||||
modules_to_not_convert=["lm_head"],)
|
||||
model = model.to('cpu')
|
||||
# model.gradient_checkpointing_enable()
|
||||
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
|
||||
model.enable_input_require_grads()
|
||||
config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=32,
|
||||
target_modules=["q_proj", "k_proj", "v_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
model = get_peft_model(model, config)
|
||||
tokenizer.pad_token_id = 0
|
||||
tokenizer.padding_side = "left"
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
train_dataset=data["train"],
|
||||
args=transformers.TrainingArguments(
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps= 1,
|
||||
warmup_steps=20,
|
||||
max_steps=200,
|
||||
learning_rate=2e-4,
|
||||
save_steps=100,
|
||||
bf16=True,
|
||||
logging_steps=20,
|
||||
output_dir="outputs",
|
||||
optim="adamw_hf", # paged_adamw_8bit is not supported yet
|
||||
# gradient_checkpointing=True, # can further reduce memory but slower
|
||||
),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||
)
|
||||
model.config.use_cache = False # silence the warnings. Please re-enable for inference!
|
||||
result = trainer.train()
|
||||
print(result)
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
|
||||
set -x
|
||||
export USE_XETLA=OFF
|
||||
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
if [ -d "./model" ];
|
||||
then
|
||||
MODEL_PARAM="--repo-id-or-model-path ./model" # otherwise, default to download from HF repo
|
||||
fi
|
||||
|
||||
if [ -d "./data/english_quotes" ];
|
||||
then
|
||||
DATA_PARAM="--dataset ./data/english_quotes" # otherwise, default to download from HF dataset
|
||||
fi
|
||||
|
||||
python qlora_finetuning_cpu.py $MODEL_PARAM $DATA_PARAM
|
||||
|
||||
Loading…
Reference in a new issue