BF16 Lora Finetuning on K8S with OneCCL and Intel MPI (#8775)
* BF16 Lora Finetuning on K8S with OneCCL and Intel MPI * Update README.md * format * refine * Update README.md * refine * Update README.md * increase nfs volume size to improve IO performance * fix bugs * Update README.md * Update README.md * fix permission * move output destination * Update README.md * fix wrong base model name in doc * fix output path in entrypoint * add a permission-precreated output dir * format * move output logs to a persistent storage
This commit is contained in:
parent
de6c6bb17f
commit
b1ac8dc1bc
12 changed files with 654 additions and 0 deletions
55
docker/llm/finetune/lora/README.md
Normal file
55
docker/llm/finetune/lora/README.md
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
## Run BF16-Optimized Lora Finetuning on Kubernetes with OneCCL
|
||||
|
||||
[Alpaca Lora](https://github.com/tloen/alpaca-lora/tree/main) uses [low-rank adaption](https://arxiv.org/pdf/2106.09685.pdf) to speed up the finetuning process of base model [Llama 7b](https://huggingface.co/decapoda-research/llama-7b-hf), and tries to reproduce the standard Alpaca, a general finetuned LLM. This is on top of Hugging Face transformers with Pytorch backend, which natively requires a number of expensive GPU resources and takes significant time.
|
||||
|
||||
By constract, BigDL here provides a CPU optimization to accelerate the lora finetuning of Llama 7b, in the power of mixed-precision and distributed training. Detailedly, [Intel OneCCL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/oneccl.html), an available Hugging Face backend, is able to speed up the Pytorch computation with BF16 datatype on CPUs, as well as parallel processing on Kubernetes enabled by [Intel MPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/mpi-library.html).
|
||||
|
||||
The architecture is illustrated in the following:
|
||||
|
||||

|
||||
|
||||
As above, BigDL implements its MPI training build on [Kubeflow MPI operator](https://github.com/kubeflow/mpi-operator/tree/master), which encapsulates the deployment as MPIJob CRD, and assists users to handle the construction of a MPI worker cluster on Kubernetes, such as public key distribution, SSH connection, and log collection.
|
||||
|
||||
Now, let's go to deploy a Lora finetuning to create a LLM from Llama 7b.
|
||||
|
||||
**Note: Please make sure you have already have an available Kubernetes infrastructure and NFS shared storage, and install [Helm CLI](https://helm.sh/docs/helm/helm_install/) for Kubernetes job submission.**
|
||||
|
||||
### 1. Install Kubeflow MPI Operator
|
||||
|
||||
Follow [here](https://github.com/kubeflow/mpi-operator/tree/master#installation) to install a Kubeflow MPI operator in your Kubernetes, which will listen and receive the following MPIJob request at backend.
|
||||
|
||||
### 2. Download Image, Base Model and Finetuning Data
|
||||
|
||||
Follow [here](https://github.com/intel-analytics/BigDL/tree/main/docker/llm/finetune/lora/docker#prepare-bigdl-image-for-lora-finetuning) to prepare BigDL Lora Finetuning image in your cluster.
|
||||
|
||||
As finetuning is from a base model, first download [Llama 7b hf model from the public download site of Hugging Face](https://huggingface.co/decapoda-research/llama-7b-hf/tree/main). Then, download [cleaned alpaca data](https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json), which contains all kinds of general knowledge and has already been cleaned. Next, move the downloaded files to a shared directory on your NFS server. In addition, make an empty directory under the same destination to save the finetuned model output later.
|
||||
|
||||
### 3. Deploy through Helm Chart
|
||||
|
||||
You are allowed to edit and experiment with different parameters in `./kubernetes/values.yaml` to improve finetuning performance and accuracy. For example, you can adjust `trainerNum` and `cpuPerPod` according to node and CPU core numbers in your cluster to make full use of these resources, and different `microBatchSize` result in different training speed and loss (here note that `microBatchSize`×`trainerNum` should not more than 128, as it is the batch size).
|
||||
|
||||
** Note: `dataSubPath`, `modelSubPath` and `outputPath` need to have the same names as files under the NFS directory in step 2. **
|
||||
|
||||
After preparing parameters in `./kubernetes/values.yaml`, submit the job as beflow:
|
||||
|
||||
```bash
|
||||
cd ./kubernetes
|
||||
helm install bigdl-lora-finetuning .
|
||||
```
|
||||
|
||||
### 4. Check Deployment
|
||||
```bash
|
||||
kubectl get all -n bigdl-lora-finetuning # you will see launcher and worker pods running
|
||||
```
|
||||
|
||||
### 5. Check Finetuning Process
|
||||
|
||||
After deploying successfully, you can find a launcher pod, and then go inside this pod and check the logs collected from all workers.
|
||||
|
||||
```bash
|
||||
kubectl get all -n bigdl-lora-finetuning # you will see a launcher pod
|
||||
kubectl exec -it <launcher_pod_name> bash -n bigdl-ppml-finetuning # enter launcher pod
|
||||
cat launcher.log # display logs collected from other workers
|
||||
```
|
||||
|
||||
From the log, you can see whether finetuning process has been invoked successfully in all MPI worker pods, and a progress bar with finetuning speed and estimated time will be showed after some data preprocessing steps (this may take quiet a while).
|
||||
58
docker/llm/finetune/lora/docker/Dockerfile
Normal file
58
docker/llm/finetune/lora/docker/Dockerfile
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
ARG HTTP_PROXY
|
||||
ARG HTTPS_PROXY
|
||||
|
||||
FROM mpioperator/intel as builder
|
||||
|
||||
ARG HTTP_PROXY
|
||||
ARG HTTPS_PROXY
|
||||
ADD ./requirements.txt /ppml/requirements.txt
|
||||
|
||||
RUN mkdir /ppml/data && mkdir /ppml/model && mkdir /ppml/output && \
|
||||
# install pytorch 2.0.1
|
||||
export http_proxy=$HTTP_PROXY && \
|
||||
export https_proxy=$HTTPS_PROXY && \
|
||||
apt-get update && \
|
||||
apt-get install -y python3-pip python3.9-dev python3-wheel && \
|
||||
pip3 install --upgrade pip && \
|
||||
pip install torch==2.0.1 && \
|
||||
# install ipex and oneccl
|
||||
pip install intel_extension_for_pytorch==2.0.100 && \
|
||||
pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable && \
|
||||
# install transformers etc.
|
||||
cd /ppml && \
|
||||
apt-get update && \
|
||||
apt-get install -y git && \
|
||||
git clone https://github.com/huggingface/transformers.git && \
|
||||
cd transformers && \
|
||||
git reset --hard 057e1d74733f52817dc05b673a340b4e3ebea08c && \
|
||||
pip install . && \
|
||||
pip install -r /ppml/requirements.txt && \
|
||||
# install python
|
||||
env DEBIAN_FRONTEND=noninteractive apt-get update && \
|
||||
apt install software-properties-common -y && \
|
||||
add-apt-repository ppa:deadsnakes/ppa -y && \
|
||||
apt-get install -y python3.9 && \
|
||||
rm /usr/bin/python3 && \
|
||||
ln -s /usr/bin/python3.9 /usr/bin/python3 && \
|
||||
ln -s /usr/bin/python3 /usr/bin/python && \
|
||||
apt-get install -y python3-pip python3.9-dev python3-wheel && \
|
||||
pip install --upgrade pip && \
|
||||
pip install --no-cache requests argparse cryptography==3.3.2 urllib3 && \
|
||||
pip install --upgrade requests && \
|
||||
pip install setuptools==58.4.0 && \
|
||||
# Install OpenSSH for MPI to communicate between containers
|
||||
apt-get install -y --no-install-recommends openssh-client openssh-server && \
|
||||
mkdir -p /var/run/sshd && \
|
||||
# Allow OpenSSH to talk to containers without asking for confirmation
|
||||
# by disabling StrictHostKeyChecking.
|
||||
# mpi-operator mounts the .ssh folder from a Secret. For that to work, we need
|
||||
# to disable UserKnownHostsFile to avoid write permissions.
|
||||
# Disabling StrictModes avoids directory and files read permission checks.
|
||||
sed -i 's/[ #]\(.*StrictHostKeyChecking \).*/ \1no/g' /etc/ssh/ssh_config && \
|
||||
echo " UserKnownHostsFile /dev/null" >> /etc/ssh/ssh_config && \
|
||||
sed -i 's/#\(StrictModes \).*/\1no/g' /etc/ssh/sshd_config
|
||||
|
||||
ADD ./bigdl-lora-finetuing-entrypoint.sh /ppml/bigdl-lora-finetuing-entrypoint.sh
|
||||
ADD ./lora_finetune.py /ppml/lora_finetune.py
|
||||
RUN chown -R mpiuser /ppml
|
||||
USER mpiuser
|
||||
20
docker/llm/finetune/lora/docker/README.md
Normal file
20
docker/llm/finetune/lora/docker/README.md
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
## Prepare BigDL image for Lora Finetuning
|
||||
|
||||
You can download directly from Dockerhub like:
|
||||
|
||||
```bash
|
||||
docker pull intelanalytics/bigdl-lora-finetuning:2.4.0-SNAPSHOT
|
||||
```
|
||||
|
||||
Or build the image from source:
|
||||
|
||||
```bash
|
||||
export HTTP_PROXY=your_http_proxy
|
||||
export HTTPS_PROXY=your_https_proxy
|
||||
|
||||
docker build \
|
||||
--build-arg HTTP_PROXY=${HTTP_PROXY} \
|
||||
--build-arg HTTPS_PROXY=${HTTPS_PROXY} \
|
||||
-t intelanalytics/bigdl-lora-finetuning:2.4.0-SNAPSHOT \
|
||||
-f ./Dockerfile .
|
||||
```
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
#!/bin/bash
|
||||
set -x
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
export CCL_WORKER_COUNT=$WORLD_SIZE
|
||||
export CCL_WORKER_AFFINITY=auto
|
||||
|
||||
if [ "$WORKER_ROLE" = "launcher" ]
|
||||
then
|
||||
sed "s/:1/ /g" /etc/mpi/hostfile > /home/mpiuser/hostfile
|
||||
export DATA_PATH="/ppml/data/$DATA_SUB_PATH"
|
||||
export SAVE_PATH="/ppml/output"
|
||||
sleep 10
|
||||
mpirun \
|
||||
-n $WORLD_SIZE \
|
||||
-ppn 1 \
|
||||
-f /home/mpiuser/hostfile \
|
||||
-iface eth0 \
|
||||
-genv OMP_NUM_THREADS=$OMP_NUM_THREADS \
|
||||
-genv KMP_AFFINITY="granularity=fine,none" \
|
||||
-genv KMP_BLOCKTIME=1 \
|
||||
-genv TF_ENABLE_ONEDNN_OPTS=1 \
|
||||
python /ppml/lora_finetune.py \
|
||||
--base_model '/ppml/model/' \
|
||||
--data_path "$DATA_PATH" \
|
||||
--output_dir "$SAVE_PATH/finetuned_model" \
|
||||
--micro_batch_size $MICRO_BATCH_SIZE \
|
||||
--bf16 > $SAVE_PATH/launcher.log 2>&1
|
||||
exit_status=$?
|
||||
if [ $exit_status -ne 0 ];
|
||||
then
|
||||
cat launcher.log
|
||||
exit $exit_status
|
||||
else
|
||||
while true
|
||||
do
|
||||
echo "[INFO] Successfully finished training"
|
||||
sleep 900
|
||||
done
|
||||
fi
|
||||
elif [ "$WORKER_ROLE" = "trainer" ]
|
||||
then
|
||||
export LOCAL_RANK=$(cut -d "-" -f6 <<< "$LOCAL_POD_NAME")
|
||||
export PMI_SIZE=$WORLD_SIZE
|
||||
export PMI_RANK=$LOCAL_RANK
|
||||
/usr/sbin/sshd -De -f /home/mpiuser/.sshd_config
|
||||
fi
|
||||
|
||||
316
docker/llm/finetune/lora/docker/lora_finetune.py
Normal file
316
docker/llm/finetune/lora/docker/lora_finetune.py
Normal file
|
|
@ -0,0 +1,316 @@
|
|||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
import time
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
"""
|
||||
Unused imports:
|
||||
import torch.nn as nn
|
||||
import bitsandbytes as bnb
|
||||
"""
|
||||
|
||||
# Catch when user should re-install transformers library
|
||||
# assert (
|
||||
# "LlamaTokenizer" in transformers._import_structure["models.llama"]
|
||||
# ), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git" # noqa: E501
|
||||
|
||||
from peft import ( # noqa: E402
|
||||
LoraConfig,
|
||||
get_peft_model,
|
||||
get_peft_model_state_dict,
|
||||
prepare_model_for_int8_training,
|
||||
set_peft_model_state_dict,
|
||||
)
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
|
||||
def train(
|
||||
# model/data params
|
||||
base_model: str = "", # the only required argument
|
||||
data_path: str = "./alpaca_data_cleaned.json",
|
||||
output_dir: str = "./lora-alpaca",
|
||||
# training hyperparams
|
||||
batch_size: int = 128,
|
||||
micro_batch_size: int = 4,
|
||||
num_epochs: int = 3,
|
||||
learning_rate: float = 3e-4,
|
||||
cutoff_len: int = 256,
|
||||
val_set_size: int = 2000,
|
||||
# lora hyperparams
|
||||
lora_r: int = 8,
|
||||
lora_alpha: int = 16,
|
||||
lora_dropout: float = 0.05,
|
||||
lora_target_modules: List[str] = [
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
],
|
||||
# llm hyperparams
|
||||
train_on_inputs: bool = True, # if False, masks out inputs in loss
|
||||
group_by_length: bool = False, # faster, but produces an odd training loss curve
|
||||
# wandb params
|
||||
wandb_project: str = "",
|
||||
wandb_run_name: str = "",
|
||||
wandb_watch: str = "", # options: false | gradients | all
|
||||
wandb_log_model: str = "", # options: false | true
|
||||
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
|
||||
use_ipex: bool = False,
|
||||
bf16: bool = False,
|
||||
no_cuda: bool=True,
|
||||
xpu_backend: str="ccl"
|
||||
):
|
||||
print(
|
||||
f"Training Alpaca-LoRA model with params:\n"
|
||||
f"base_model: {base_model}\n"
|
||||
f"data_path: {data_path}\n"
|
||||
f"output_dir: {output_dir}\n"
|
||||
f"batch_size: {batch_size}\n"
|
||||
f"micro_batch_size: {micro_batch_size}\n"
|
||||
f"num_epochs: {num_epochs}\n"
|
||||
f"learning_rate: {learning_rate}\n"
|
||||
f"cutoff_len: {cutoff_len}\n"
|
||||
f"val_set_size: {val_set_size}\n"
|
||||
f"lora_r: {lora_r}\n"
|
||||
f"lora_alpha: {lora_alpha}\n"
|
||||
f"lora_dropout: {lora_dropout}\n"
|
||||
f"lora_target_modules: {lora_target_modules}\n"
|
||||
f"train_on_inputs: {train_on_inputs}\n"
|
||||
f"group_by_length: {group_by_length}\n"
|
||||
f"wandb_project: {wandb_project}\n"
|
||||
f"wandb_run_name: {wandb_run_name}\n"
|
||||
f"wandb_watch: {wandb_watch}\n"
|
||||
f"wandb_log_model: {wandb_log_model}\n"
|
||||
f"resume_from_checkpoint: {resume_from_checkpoint}\n"
|
||||
f"use_ipex: {use_ipex}\n"
|
||||
f"bf16: {bf16}\n"
|
||||
)
|
||||
assert (
|
||||
base_model
|
||||
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
|
||||
gradient_accumulation_steps = batch_size // micro_batch_size
|
||||
|
||||
device_map = "auto"
|
||||
pmi_world_size = int(os.environ.get('PMI_SIZE', -1))
|
||||
if pmi_world_size > 0:
|
||||
os.environ['WORLD_SIZE'] = str(pmi_world_size)
|
||||
else:
|
||||
os.environ['WORLD_SIZE'] = str(os.environ.get('WORLD_SIZE', 1))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
print(f"world_size: {world_size}!!")
|
||||
ddp = world_size != 1
|
||||
local_rank = 0
|
||||
if ddp:
|
||||
os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0))
|
||||
os.environ['LOCAL_RANK'] = str(os.environ.get('PMI_RANK', 0))
|
||||
local_rank = str(os.environ.get('PMI_RANK', 0))
|
||||
print("PMI_RANK(local_rank): " + local_rank)
|
||||
gradient_accumulation_steps = gradient_accumulation_steps // world_size
|
||||
|
||||
# Check if parameter passed or if set within environ
|
||||
use_wandb = len(wandb_project) > 0 or \
|
||||
("WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
|
||||
# Only overwrite environ if wandb param passed
|
||||
if len(wandb_project) > 0:
|
||||
os.environ['WANDB_PROJECT'] = wandb_project
|
||||
if len(wandb_watch) > 0:
|
||||
os.environ['WANDB_WATCH'] = wandb_watch
|
||||
if len(wandb_log_model) > 0:
|
||||
os.environ['WANDB_LOG_MODEL'] = wandb_log_model
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(base_model)
|
||||
|
||||
tokenizer.pad_token_id = (
|
||||
0 # unk. we want this to be different from the eos token
|
||||
)
|
||||
tokenizer.padding_side = "left" # Allow batched inference
|
||||
|
||||
def tokenize(prompt, add_eos_token=True):
|
||||
# there's probably a way to do this with the tokenizer settings
|
||||
# but again, gotta move fast
|
||||
result = tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=cutoff_len,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
if (
|
||||
result["input_ids"][-1] != tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < cutoff_len
|
||||
and add_eos_token
|
||||
):
|
||||
result["input_ids"].append(tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
|
||||
return result
|
||||
|
||||
def generate_and_tokenize_prompt(data_point):
|
||||
full_prompt = generate_prompt(data_point)
|
||||
tokenized_full_prompt = tokenize(full_prompt)
|
||||
if not train_on_inputs:
|
||||
user_prompt = generate_prompt({**data_point, "output": ""})
|
||||
tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
|
||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||
|
||||
tokenized_full_prompt["labels"] = [
|
||||
-100
|
||||
] * user_prompt_len + tokenized_full_prompt["labels"][
|
||||
user_prompt_len:
|
||||
] # could be sped up, probably
|
||||
return tokenized_full_prompt
|
||||
|
||||
model = prepare_model_for_int8_training(model)
|
||||
|
||||
config = LoraConfig(
|
||||
r=lora_r,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=lora_target_modules,
|
||||
lora_dropout=lora_dropout,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
model = get_peft_model(model, config)
|
||||
|
||||
if data_path.endswith(".json"): # todo: support jsonl
|
||||
data = load_dataset("json", data_files=data_path)
|
||||
else:
|
||||
data = load_dataset(data_path)
|
||||
|
||||
if resume_from_checkpoint:
|
||||
# Check the available weights and load them
|
||||
checkpoint_name = os.path.join(
|
||||
resume_from_checkpoint, "pytorch_model.bin"
|
||||
) # Full checkpoint
|
||||
if not os.path.exists(checkpoint_name):
|
||||
checkpoint_name = os.path.join(
|
||||
resume_from_checkpoint, "adapter_model.bin"
|
||||
) # only LoRA model - LoRA config above has to fit
|
||||
resume_from_checkpoint = (
|
||||
False # So the trainer won't try loading its state
|
||||
)
|
||||
# The two files above have a different name depending on how they were saved, but are actually the same.
|
||||
if os.path.exists(checkpoint_name):
|
||||
print(f"Restarting from {checkpoint_name}")
|
||||
adapters_weights = torch.load(checkpoint_name)
|
||||
model = set_peft_model_state_dict(model, adapters_weights)
|
||||
else:
|
||||
print(f"Checkpoint {checkpoint_name} not found")
|
||||
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
|
||||
|
||||
if val_set_size > 0:
|
||||
print("[INFO] spliting and shuffling dataset...")
|
||||
train_val = data["train"].train_test_split(
|
||||
test_size=val_set_size, shuffle=True, seed=42
|
||||
)
|
||||
print("[INFO] shuffling and tokenizing train data...")
|
||||
train_data = (
|
||||
train_val["train"].shuffle().map(generate_and_tokenize_prompt)
|
||||
)
|
||||
print("[INFO] shuffling and tokenizing test data...")
|
||||
val_data = (
|
||||
train_val["test"].shuffle().map(generate_and_tokenize_prompt)
|
||||
)
|
||||
else:
|
||||
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
|
||||
val_data = None
|
||||
print("[INFO] begining the training of transformers...")
|
||||
|
||||
args=transformers.TrainingArguments(
|
||||
per_device_train_batch_size=micro_batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
warmup_steps=100,
|
||||
num_train_epochs=num_epochs,
|
||||
learning_rate=learning_rate,
|
||||
bf16=bf16,
|
||||
logging_steps=10,
|
||||
optim="adamw_torch",
|
||||
evaluation_strategy="epoch",
|
||||
save_strategy="steps",
|
||||
local_rank=local_rank,
|
||||
output_dir=output_dir,
|
||||
save_total_limit=3,
|
||||
ddp_find_unused_parameters=False,
|
||||
group_by_length=group_by_length,
|
||||
report_to="wandb" if use_wandb else None,
|
||||
run_name=wandb_run_name if use_wandb else None,
|
||||
xpu_backend=xpu_backend,
|
||||
no_cuda=no_cuda
|
||||
)
|
||||
|
||||
print(
|
||||
f"[INFO] Process rank: {args.local_rank}, device: {args.device}"
|
||||
+ f"distributed training: {args.parallel_mode.value == 'distributed'}"
|
||||
)
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
train_dataset=train_data,
|
||||
eval_dataset=val_data,
|
||||
args=args,
|
||||
data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
||||
),
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
old_state_dict = model.state_dict
|
||||
model.state_dict = (
|
||||
lambda self, *_, **__: get_peft_model_state_dict(
|
||||
self, old_state_dict()
|
||||
)
|
||||
).__get__(model, type(model))
|
||||
|
||||
start = time.time()
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
end = time.time()
|
||||
print("training time is: ", end - start)
|
||||
|
||||
if int(os.environ.get("PMI_RANK", -1)) == 0:
|
||||
model.save_pretrained(output_dir)
|
||||
elif int(os.environ.get("PMI_RANK", -1)) == -1:
|
||||
model.save_pretrained(output_dir)
|
||||
|
||||
print(
|
||||
"\n If there's a warning about missing keys above, please disregard :)"
|
||||
)
|
||||
|
||||
|
||||
def generate_prompt(data_point):
|
||||
# sorry about the formatting disaster gotta move fast
|
||||
if data_point["input"]:
|
||||
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501
|
||||
|
||||
### Instruction:
|
||||
{data_point["instruction"]}
|
||||
|
||||
### Input:
|
||||
{data_point["input"]}
|
||||
|
||||
### Response:
|
||||
{data_point["output"]}"""
|
||||
else:
|
||||
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. # noqa: E501
|
||||
|
||||
### Instruction:
|
||||
{data_point["instruction"]}
|
||||
|
||||
### Response:
|
||||
{data_point["output"]}"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(train)
|
||||
13
docker/llm/finetune/lora/docker/requirements.txt
Normal file
13
docker/llm/finetune/lora/docker/requirements.txt
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
accelerate
|
||||
appdirs
|
||||
bitsandbytes
|
||||
black
|
||||
black[jupyter]
|
||||
datasets
|
||||
fire
|
||||
peft==0.2.0
|
||||
#git+https://github.com/huggingface/peft.git
|
||||
#git+https://github.com/huggingface/transformers.git
|
||||
gradio
|
||||
sentencepiece
|
||||
scipy
|
||||
6
docker/llm/finetune/lora/kubernetes/Chart.yaml
Normal file
6
docker/llm/finetune/lora/kubernetes/Chart.yaml
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
apiVersion: v2
|
||||
name: trusted-fintune-service
|
||||
description: A Helm chart for BigDL PPML Trusted BigData Service on Kubernetes
|
||||
type: application
|
||||
version: 1.1.27
|
||||
appVersion: "1.16.0"
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
apiVersion: kubeflow.org/v2beta1
|
||||
kind: MPIJob
|
||||
metadata:
|
||||
name: bigdl-lora-finetuning-job
|
||||
namespace: bigdl-lora-finetuning
|
||||
spec:
|
||||
slotsPerWorker: 1
|
||||
runPolicy:
|
||||
cleanPodPolicy: Running
|
||||
sshAuthMountPath: /home/mpiuser/.ssh
|
||||
mpiImplementation: Intel
|
||||
mpiReplicaSpecs:
|
||||
Launcher:
|
||||
replicas: 1
|
||||
template:
|
||||
spec:
|
||||
volumes:
|
||||
- name: nfs-storage
|
||||
persistentVolumeClaim:
|
||||
claimName: nfs-pvc
|
||||
containers:
|
||||
- image: {{ .Values.imageName }}
|
||||
name: bigdl-ppml-finetuning-launcher
|
||||
securityContext:
|
||||
runAsUser: 1000
|
||||
command: ['sh' , '-c', 'bash /ppml/bigdl-lora-finetuing-entrypoint.sh']
|
||||
env:
|
||||
- name: WORKER_ROLE
|
||||
value: "launcher"
|
||||
- name: WORLD_SIZE
|
||||
value: "{{ .Values.trainerNum }}"
|
||||
- name: MICRO_BATCH_SIZE
|
||||
value: "{{ .Values.microBatchSize }}"
|
||||
- name: MASTER_PORT
|
||||
value: "42679"
|
||||
- name: MASTER_ADDR
|
||||
value: "bigdl-lora-finetuning-job-worker-0.bigdl-lora-finetuning-job-worker"
|
||||
- name: DATA_SUB_PATH
|
||||
value: "{{ .Values.dataSubPath }}"
|
||||
- name: OMP_NUM_THREADS
|
||||
value: "{{ .Values.ompNumThreads }}"
|
||||
- name: LOCAL_POD_NAME
|
||||
valueFrom:
|
||||
fieldRef:
|
||||
fieldPath: metadata.name
|
||||
volumeMounts:
|
||||
- name: nfs-storage
|
||||
subPath: {{ .Values.modelSubPath }}
|
||||
mountPath: /ppml/model
|
||||
- name: nfs-storage
|
||||
subPath: {{ .Values.dataSubPath }}
|
||||
mountPath: "/ppml/data/{{ .Values.dataSubPath }}"
|
||||
- name: nfs-storage
|
||||
subPath: {{ .Values.outputSubPath }}
|
||||
mountPath: "/ppml/output"
|
||||
Worker:
|
||||
replicas: {{ .Values.trainerNum }}
|
||||
template:
|
||||
spec:
|
||||
containers:
|
||||
- image: {{ .Values.imageName }}
|
||||
name: bigdl-ppml-finetuning-worker
|
||||
securityContext:
|
||||
runAsUser: 1000
|
||||
command: ['sh' , '-c', 'bash /ppml/bigdl-lora-finetuing-entrypoint.sh']
|
||||
env:
|
||||
- name: WORKER_ROLE
|
||||
value: "trainer"
|
||||
- name: WORLD_SIZE
|
||||
value: "{{ .Values.trainerNum }}"
|
||||
- name: MICRO_BATCH_SIZE
|
||||
value: "{{ .Values.microBatchSize }}"
|
||||
- name: MASTER_PORT
|
||||
value: "42679"
|
||||
- name: MASTER_ADDR
|
||||
value: "bigdl-lora-finetuning-job-worker-0.bigdl-lora-finetuning-job-worker"
|
||||
- name: LOCAL_POD_NAME
|
||||
valueFrom:
|
||||
fieldRef:
|
||||
fieldPath: metadata.name
|
||||
volumeMounts:
|
||||
- name: nfs-storage
|
||||
subPath: {{ .Values.modelSubPath }}
|
||||
mountPath: /ppml/model
|
||||
- name: nfs-storage
|
||||
subPath: {{ .Values.dataSubPath }}
|
||||
mountPath: "/ppml/data/{{ .Values.dataSubPath }}"
|
||||
- name: nfs-storage
|
||||
subPath: {{ .Values.outputSubPath }}
|
||||
mountPath: "/ppml/output"
|
||||
resources:
|
||||
requests:
|
||||
cpu: {{ .Values.cpuPerPod }}
|
||||
volumes:
|
||||
- name: nfs-storage
|
||||
persistentVolumeClaim:
|
||||
claimName: nfs-pvc
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
apiVersion: v1
|
||||
kind: Namespace
|
||||
metadata:
|
||||
name: bigdl-lora-finetuning
|
||||
15
docker/llm/finetune/lora/kubernetes/templates/nfs-pv.yaml
Normal file
15
docker/llm/finetune/lora/kubernetes/templates/nfs-pv.yaml
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
apiVersion: v1
|
||||
kind: PersistentVolume
|
||||
metadata:
|
||||
name: nfs-pv-bigdl-lora-finetuning
|
||||
namespace: bigdl-lora-finetuning
|
||||
spec:
|
||||
capacity:
|
||||
storage: 15Gi
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
persistentVolumeReclaimPolicy: Retain
|
||||
storageClassName: nfs
|
||||
nfs:
|
||||
path: {{ .Values.nfsPath }}
|
||||
server: {{ .Values.nfsServerIp }}
|
||||
12
docker/llm/finetune/lora/kubernetes/templates/nfs-pvc.yaml
Normal file
12
docker/llm/finetune/lora/kubernetes/templates/nfs-pvc.yaml
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
kind: PersistentVolumeClaim
|
||||
apiVersion: v1
|
||||
metadata:
|
||||
name: nfs-pvc
|
||||
namespace: bigdl-lora-finetuning
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
resources:
|
||||
requests:
|
||||
storage: 10Gi
|
||||
storageClassName: nfs
|
||||
11
docker/llm/finetune/lora/kubernetes/values.yaml
Normal file
11
docker/llm/finetune/lora/kubernetes/values.yaml
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
imageName: intelanalytics/bigdl-lora-finetuning:2.4.0-SNAPSHOT
|
||||
trainerNum: 8
|
||||
microBatchSize: 8
|
||||
TEEMode: tdx # tdx or native
|
||||
nfsServerIp: your_nfs_server_ip
|
||||
nfsPath: a_nfs_shared_folder_path_on_the_server
|
||||
dataSubPath: alpaca_data_cleaned_archive.json # a subpath of the data file under nfs directory
|
||||
modelSubPath: llama-7b-hf # a subpath of the model file (dir) under nfs directory
|
||||
outputSubPath: output # a subpath of the empty directory under the nfs directory to save finetuned model, for example, if you make an empty dir named 'output' at the nfsPath, the value should be 'output'
|
||||
ompNumThreads: 14
|
||||
cpuPerPod: 42
|
||||
Loading…
Reference in a new issue