BF16 Lora Finetuning on K8S with OneCCL and Intel MPI (#8775)

* BF16 Lora Finetuning on K8S with OneCCL and Intel MPI

* Update README.md

* format

* refine

* Update README.md

* refine

* Update README.md

* increase nfs volume size to improve IO performance

* fix bugs

* Update README.md

* Update README.md

* fix permission

* move output destination

* Update README.md

* fix wrong base model name in doc

* fix output path in entrypoint

* add a permission-precreated output dir

* format

* move output logs to a persistent storage
Heyang Sun 2023-08-31 14:56:23 +08:00 committed by GitHub
parent de6c6bb17f
commit b1ac8dc1bc
12 changed files with 654 additions and 0 deletions

@@ -0,0 +1,55 @@
## Run BF16-Optimized Lora Finetuning on Kubernetes with OneCCL
[Alpaca Lora](https://github.com/tloen/alpaca-lora/tree/main) uses [low-rank adaptation (LoRA)](https://arxiv.org/pdf/2106.09685.pdf) to speed up the finetuning of the base model [Llama 7b](https://huggingface.co/decapoda-research/llama-7b-hf), aiming to reproduce the standard Alpaca, a general-purpose finetuned LLM. It is built on top of Hugging Face Transformers with a PyTorch backend, which natively requires a number of expensive GPUs and a significant amount of training time.
By contrast, BigDL provides a CPU optimization that accelerates LoRA finetuning of Llama 7b through mixed-precision and distributed training. Specifically, [Intel oneCCL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/oneccl.html), available as a Hugging Face backend, speeds up PyTorch computation with the BF16 data type on CPUs, while [Intel MPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/mpi-library.html) enables parallel processing across Kubernetes pods.
The architecture is illustrated below:
![image](https://github.com/Uxito-Ada/BigDL/assets/60865256/139cf9be-10e6-48df-bc84-8872457e83dd)
As shown above, BigDL builds its MPI training on the [Kubeflow MPI operator](https://github.com/kubeflow/mpi-operator/tree/master), which encapsulates the deployment as an MPIJob CRD and helps users construct an MPI worker cluster on Kubernetes, handling details such as public key distribution, SSH connection, and log collection.
Now, let's deploy a LoRA finetuning job to create an LLM from Llama 7b.
**Note: Please make sure you already have an available Kubernetes cluster and NFS shared storage, and have installed the [Helm CLI](https://helm.sh/docs/helm/helm_install/) for Kubernetes job submission.**
### 1. Install Kubeflow MPI Operator
Follow [the instructions here](https://github.com/kubeflow/mpi-operator/tree/master#installation) to install the Kubeflow MPI operator in your Kubernetes cluster; it will listen for and handle the MPIJob requests submitted in the following steps.
### 2. Download Image, Base Model and Finetuning Data
Follow [the instructions here](https://github.com/intel-analytics/BigDL/tree/main/docker/llm/finetune/lora/docker#prepare-bigdl-image-for-lora-finetuning) to prepare the BigDL LoRA finetuning image in your cluster.
Since finetuning starts from a base model, first download the [Llama 7b hf model from the public download site of Hugging Face](https://huggingface.co/decapoda-research/llama-7b-hf/tree/main). Then, download the [cleaned alpaca data](https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json), which covers a broad range of general knowledge and has already been cleaned. Next, move the downloaded files to a shared directory on your NFS server. In addition, create an empty directory under the same destination to save the finetuned model output later.
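For example, assuming the NFS export is mounted at `/mnt/nfs` on the NFS server (this path is only an illustration; it corresponds to the `nfsPath` value you set in step 3), the shared directory could be prepared as follows. The `chown` to UID 1000 matches the `mpiuser` account the finetuning containers run as:
```bash
# On the NFS server; /mnt/nfs is a hypothetical export path (your nfsPath).
cd /mnt/nfs
ls llama-7b-hf                       # base model directory downloaded from Hugging Face
ls alpaca_data_cleaned_archive.json  # cleaned Alpaca finetuning data
mkdir -p output                      # empty directory for the finetuned model and training logs
chown -R 1000:1000 output            # containers run as UID 1000 (mpiuser), which needs write access here
```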
### 3. Deploy through Helm Chart
You are allowed to edit and experiment with different parameters in `./kubernetes/values.yaml` to improve finetuning performance and accuracy. For example, you can adjust `trainerNum` and `cpuPerPod` according to the number of nodes and CPU cores in your cluster to make full use of these resources, and different `microBatchSize` values result in different training speed and loss (note that `microBatchSize` × `trainerNum` should not exceed 128, which is the global batch size).
**Note: `dataSubPath`, `modelSubPath` and `outputSubPath` need to match the names of the files and directories placed under the NFS directory in step 2.**
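Before installing, you can optionally preview the manifests that Helm will render from your `values.yaml` (a standard Helm workflow, shown here only as a convenience):
```bash
cd ./kubernetes
helm template . | less  # renders the namespace, PV/PVC and MPIJob templates with your values
```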
After preparing the parameters in `./kubernetes/values.yaml`, submit the job as below:
```bash
cd ./kubernetes
helm install bigdl-lora-finetuning .
```
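If you adjust `values.yaml` afterwards, the release can be updated in place, or removed when you are done, with the standard Helm commands:
```bash
helm upgrade bigdl-lora-finetuning .   # apply new values to the running release
helm uninstall bigdl-lora-finetuning   # tear the finetuning deployment down
```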
### 4. Check Deployment
```bash
kubectl get all -n bigdl-lora-finetuning # you will see launcher and worker pods running
```
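Because the deployment is managed by the MPI operator as an MPIJob custom resource, you can also inspect the job object itself (`bigdl-lora-finetuning-job` is the name defined in the chart's MPIJob template):
```bash
kubectl get mpijob -n bigdl-lora-finetuning
kubectl describe mpijob bigdl-lora-finetuning-job -n bigdl-lora-finetuning
```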
### 5. Check Finetuning Process
After a successful deployment, you can find the launcher pod; go inside it to check the logs collected from all workers.
```bash
kubectl get all -n bigdl-lora-finetuning # you will see a launcher pod
kubectl exec -it <launcher_pod_name> -n bigdl-lora-finetuning -- bash # enter the launcher pod
cat /ppml/output/launcher.log # display the logs collected from all workers
```
From the log, you can see whether the finetuning process has been invoked successfully in all MPI worker pods; a progress bar with the finetuning speed and estimated time will be shown after some data preprocessing steps (this may take quite a while).
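When training completes, the LoRA adapter and the collected log are written to the output directory on the shared storage (the `outputSubPath` directory created in step 2), so they can be picked up directly on the NFS server. The mount point below is again a hypothetical example:
```bash
# On the NFS server; /mnt/nfs is a hypothetical export path (your nfsPath).
ls /mnt/nfs/output/finetuned_model  # LoRA adapter weights saved by the rank-0 training process
cat /mnt/nfs/output/launcher.log    # the same training log that the launcher collects
```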

@@ -0,0 +1,58 @@
ARG HTTP_PROXY
ARG HTTPS_PROXY
FROM mpioperator/intel as builder
ARG HTTP_PROXY
ARG HTTPS_PROXY
ADD ./requirements.txt /ppml/requirements.txt
RUN mkdir /ppml/data && mkdir /ppml/model && mkdir /ppml/output && \
# install pytorch 2.0.1
export http_proxy=$HTTP_PROXY && \
export https_proxy=$HTTPS_PROXY && \
apt-get update && \
apt-get install -y python3-pip python3.9-dev python3-wheel && \
pip3 install --upgrade pip && \
pip install torch==2.0.1 && \
# install ipex and oneccl
pip install intel_extension_for_pytorch==2.0.100 && \
pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable && \
# install transformers etc.
cd /ppml && \
apt-get update && \
apt-get install -y git && \
git clone https://github.com/huggingface/transformers.git && \
cd transformers && \
git reset --hard 057e1d74733f52817dc05b673a340b4e3ebea08c && \
pip install . && \
pip install -r /ppml/requirements.txt && \
# install python
env DEBIAN_FRONTEND=noninteractive apt-get update && \
apt install software-properties-common -y && \
add-apt-repository ppa:deadsnakes/ppa -y && \
apt-get install -y python3.9 && \
rm /usr/bin/python3 && \
ln -s /usr/bin/python3.9 /usr/bin/python3 && \
ln -s /usr/bin/python3 /usr/bin/python && \
apt-get install -y python3-pip python3.9-dev python3-wheel && \
pip install --upgrade pip && \
pip install --no-cache requests argparse cryptography==3.3.2 urllib3 && \
pip install --upgrade requests && \
pip install setuptools==58.4.0 && \
# Install OpenSSH for MPI to communicate between containers
apt-get install -y --no-install-recommends openssh-client openssh-server && \
mkdir -p /var/run/sshd && \
# Allow OpenSSH to talk to containers without asking for confirmation
# by disabling StrictHostKeyChecking.
# mpi-operator mounts the .ssh folder from a Secret. For that to work, we need
# to disable UserKnownHostsFile to avoid write permissions.
# Disabling StrictModes avoids directory and files read permission checks.
sed -i 's/[ #]\(.*StrictHostKeyChecking \).*/ \1no/g' /etc/ssh/ssh_config && \
echo " UserKnownHostsFile /dev/null" >> /etc/ssh/ssh_config && \
sed -i 's/#\(StrictModes \).*/\1no/g' /etc/ssh/sshd_config
ADD ./bigdl-lora-finetuing-entrypoint.sh /ppml/bigdl-lora-finetuing-entrypoint.sh
ADD ./lora_finetune.py /ppml/lora_finetune.py
RUN chown -R mpiuser /ppml
USER mpiuser

@@ -0,0 +1,20 @@
## Prepare BigDL image for Lora Finetuning
You can pull it directly from Docker Hub:
```bash
docker pull intelanalytics/bigdl-lora-finetuning:2.4.0-SNAPSHOT
```
Or build the image from source:
```bash
export HTTP_PROXY=your_http_proxy
export HTTPS_PROXY=your_https_proxy
docker build \
--build-arg HTTP_PROXY=${HTTP_PROXY} \
--build-arg HTTPS_PROXY=${HTTPS_PROXY} \
-t intelanalytics/bigdl-lora-finetuning:2.4.0-SNAPSHOT \
-f ./Dockerfile .
```
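If your Kubernetes nodes cannot pull from Docker Hub directly, you may need to retag the image and push it to a registry the cluster can reach, then point `imageName` in the Helm chart's `values.yaml` at that location. The registry address below is a placeholder:
```bash
docker tag intelanalytics/bigdl-lora-finetuning:2.4.0-SNAPSHOT <your-registry>/bigdl-lora-finetuning:2.4.0-SNAPSHOT
docker push <your-registry>/bigdl-lora-finetuning:2.4.0-SNAPSHOT
```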

@@ -0,0 +1,47 @@
#!/bin/bash
set -x
source /opt/intel/oneapi/setvars.sh
export CCL_WORKER_COUNT=$WORLD_SIZE
export CCL_WORKER_AFFINITY=auto
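# WORKER_ROLE is injected by the Helm chart: the launcher pod builds the MPI hostfile and drives
# mpirun across the trainer pods, while each trainer only starts sshd and waits to be contacted.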
if [ "$WORKER_ROLE" = "launcher" ]
then
sed "s/:1/ /g" /etc/mpi/hostfile > /home/mpiuser/hostfile
export DATA_PATH="/ppml/data/$DATA_SUB_PATH"
export SAVE_PATH="/ppml/output"
sleep 10
mpirun \
-n $WORLD_SIZE \
-ppn 1 \
-f /home/mpiuser/hostfile \
-iface eth0 \
-genv OMP_NUM_THREADS=$OMP_NUM_THREADS \
-genv KMP_AFFINITY="granularity=fine,none" \
-genv KMP_BLOCKTIME=1 \
-genv TF_ENABLE_ONEDNN_OPTS=1 \
python /ppml/lora_finetune.py \
--base_model '/ppml/model/' \
--data_path "$DATA_PATH" \
--output_dir "$SAVE_PATH/finetuned_model" \
--micro_batch_size $MICRO_BATCH_SIZE \
--bf16 > $SAVE_PATH/launcher.log 2>&1
exit_status=$?
if [ $exit_status -ne 0 ];
then
cat $SAVE_PATH/launcher.log
exit $exit_status
else
while true
do
echo "[INFO] Successfully finished training"
sleep 900
done
fi
elif [ "$WORKER_ROLE" = "trainer" ]
then
export LOCAL_RANK=$(cut -d "-" -f6 <<< "$LOCAL_POD_NAME")
export PMI_SIZE=$WORLD_SIZE
export PMI_RANK=$LOCAL_RANK
/usr/sbin/sshd -De -f /home/mpiuser/.sshd_config
fi

@@ -0,0 +1,316 @@
import os
import sys
from typing import List
import time
import fire
import torch
from torch.utils.data import DataLoader
import transformers
from datasets import load_dataset
"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""
# Catch when user should re-install transformers library
# assert (
# "LlamaTokenizer" in transformers._import_structure["models.llama"]
# ), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git" # noqa: E501
from peft import ( # noqa: E402
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402
from transformers import BitsAndBytesConfig
def train(
# model/data params
base_model: str = "", # the only required argument
data_path: str = "./alpaca_data_cleaned.json",
output_dir: str = "./lora-alpaca",
# training hyperparams
batch_size: int = 128,
micro_batch_size: int = 4,
num_epochs: int = 3,
learning_rate: float = 3e-4,
cutoff_len: int = 256,
val_set_size: int = 2000,
# lora hyperparams
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules: List[str] = [
"q_proj",
"v_proj",
],
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
group_by_length: bool = False, # faster, but produces an odd training loss curve
# wandb params
wandb_project: str = "",
wandb_run_name: str = "",
wandb_watch: str = "", # options: false | gradients | all
wandb_log_model: str = "", # options: false | true
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
use_ipex: bool = False,
bf16: bool = False,
no_cuda: bool=True,
xpu_backend: str="ccl"
):
print(
f"Training Alpaca-LoRA model with params:\n"
f"base_model: {base_model}\n"
f"data_path: {data_path}\n"
f"output_dir: {output_dir}\n"
f"batch_size: {batch_size}\n"
f"micro_batch_size: {micro_batch_size}\n"
f"num_epochs: {num_epochs}\n"
f"learning_rate: {learning_rate}\n"
f"cutoff_len: {cutoff_len}\n"
f"val_set_size: {val_set_size}\n"
f"lora_r: {lora_r}\n"
f"lora_alpha: {lora_alpha}\n"
f"lora_dropout: {lora_dropout}\n"
f"lora_target_modules: {lora_target_modules}\n"
f"train_on_inputs: {train_on_inputs}\n"
f"group_by_length: {group_by_length}\n"
f"wandb_project: {wandb_project}\n"
f"wandb_run_name: {wandb_run_name}\n"
f"wandb_watch: {wandb_watch}\n"
f"wandb_log_model: {wandb_log_model}\n"
f"resume_from_checkpoint: {resume_from_checkpoint}\n"
f"use_ipex: {use_ipex}\n"
f"bf16: {bf16}\n"
)
assert (
base_model
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
gradient_accumulation_steps = batch_size // micro_batch_size
device_map = "auto"
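    # Intel MPI provides the process count and rank via PMI_SIZE/PMI_RANK; the block below maps
    # them onto the WORLD_SIZE/RANK/LOCAL_RANK variables expected by the Hugging Face Trainer
    # when running distributed CPU training with the ccl backend.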
pmi_world_size = int(os.environ.get('PMI_SIZE', -1))
if pmi_world_size > 0:
os.environ['WORLD_SIZE'] = str(pmi_world_size)
else:
os.environ['WORLD_SIZE'] = str(os.environ.get('WORLD_SIZE', 1))
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(f"world_size: {world_size}!!")
ddp = world_size != 1
local_rank = 0
if ddp:
os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0))
os.environ['LOCAL_RANK'] = str(os.environ.get('PMI_RANK', 0))
local_rank = str(os.environ.get('PMI_RANK', 0))
print("PMI_RANK(local_rank): " + local_rank)
gradient_accumulation_steps = gradient_accumulation_steps // world_size
# Check if parameter passed or if set within environ
use_wandb = len(wandb_project) > 0 or \
("WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
# Only overwrite environ if wandb param passed
if len(wandb_project) > 0:
os.environ['WANDB_PROJECT'] = wandb_project
if len(wandb_watch) > 0:
os.environ['WANDB_WATCH'] = wandb_watch
if len(wandb_log_model) > 0:
os.environ['WANDB_LOG_MODEL'] = wandb_log_model
model = LlamaForCausalLM.from_pretrained(
base_model,
low_cpu_mem_usage=True
)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference
def tokenize(prompt, add_eos_token=True):
# there's probably a way to do this with the tokenizer settings
# but again, gotta move fast
result = tokenizer(
prompt,
truncation=True,
max_length=cutoff_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < cutoff_len
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
def generate_and_tokenize_prompt(data_point):
full_prompt = generate_prompt(data_point)
tokenized_full_prompt = tokenize(full_prompt)
if not train_on_inputs:
user_prompt = generate_prompt({**data_point, "output": ""})
tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
tokenized_full_prompt["labels"] = [
-100
] * user_prompt_len + tokenized_full_prompt["labels"][
user_prompt_len:
] # could be sped up, probably
return tokenized_full_prompt
model = prepare_model_for_int8_training(model)
config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
if data_path.endswith(".json"): # todo: support jsonl
data = load_dataset("json", data_files=data_path)
else:
data = load_dataset(data_path)
if resume_from_checkpoint:
# Check the available weights and load them
checkpoint_name = os.path.join(
resume_from_checkpoint, "pytorch_model.bin"
) # Full checkpoint
if not os.path.exists(checkpoint_name):
checkpoint_name = os.path.join(
resume_from_checkpoint, "adapter_model.bin"
) # only LoRA model - LoRA config above has to fit
resume_from_checkpoint = (
False # So the trainer won't try loading its state
)
# The two files above have a different name depending on how they were saved, but are actually the same.
if os.path.exists(checkpoint_name):
print(f"Restarting from {checkpoint_name}")
adapters_weights = torch.load(checkpoint_name)
model = set_peft_model_state_dict(model, adapters_weights)
else:
print(f"Checkpoint {checkpoint_name} not found")
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
if val_set_size > 0:
print("[INFO] spliting and shuffling dataset...")
train_val = data["train"].train_test_split(
test_size=val_set_size, shuffle=True, seed=42
)
print("[INFO] shuffling and tokenizing train data...")
train_data = (
train_val["train"].shuffle().map(generate_and_tokenize_prompt)
)
print("[INFO] shuffling and tokenizing test data...")
val_data = (
train_val["test"].shuffle().map(generate_and_tokenize_prompt)
)
else:
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = None
print("[INFO] begining the training of transformers...")
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=100,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
bf16=bf16,
logging_steps=10,
optim="adamw_torch",
evaluation_strategy="epoch",
save_strategy="steps",
local_rank=local_rank,
output_dir=output_dir,
save_total_limit=3,
ddp_find_unused_parameters=False,
group_by_length=group_by_length,
report_to="wandb" if use_wandb else None,
run_name=wandb_run_name if use_wandb else None,
xpu_backend=xpu_backend,
no_cuda=no_cuda
)
print(
f"[INFO] Process rank: {args.local_rank}, device: {args.device}"
+ f"distributed training: {args.parallel_mode.value == 'distributed'}"
)
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=args,
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(
self, old_state_dict()
)
).__get__(model, type(model))
start = time.time()
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
end = time.time()
print("training time is: ", end - start)
if int(os.environ.get("PMI_RANK", -1)) == 0:
model.save_pretrained(output_dir)
elif int(os.environ.get("PMI_RANK", -1)) == -1:
model.save_pretrained(output_dir)
print(
"\n If there's a warning about missing keys above, please disregard :)"
)
def generate_prompt(data_point):
# sorry about the formatting disaster gotta move fast
if data_point["input"]:
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501
### Instruction:
{data_point["instruction"]}
### Input:
{data_point["input"]}
### Response:
{data_point["output"]}"""
else:
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. # noqa: E501
### Instruction:
{data_point["instruction"]}
### Response:
{data_point["output"]}"""
if __name__ == "__main__":
fire.Fire(train)

@@ -0,0 +1,13 @@
accelerate
appdirs
bitsandbytes
black
black[jupyter]
datasets
fire
peft==0.2.0
#git+https://github.com/huggingface/peft.git
#git+https://github.com/huggingface/transformers.git
gradio
sentencepiece
scipy

@@ -0,0 +1,6 @@
apiVersion: v2
name: trusted-fintune-service
description: A Helm chart for BigDL LoRA finetuning on Kubernetes
type: application
version: 1.1.27
appVersion: "1.16.0"

@@ -0,0 +1,97 @@
apiVersion: kubeflow.org/v2beta1
kind: MPIJob
metadata:
name: bigdl-lora-finetuning-job
namespace: bigdl-lora-finetuning
spec:
slotsPerWorker: 1
runPolicy:
cleanPodPolicy: Running
sshAuthMountPath: /home/mpiuser/.ssh
mpiImplementation: Intel
mpiReplicaSpecs:
Launcher:
replicas: 1
template:
spec:
volumes:
- name: nfs-storage
persistentVolumeClaim:
claimName: nfs-pvc
containers:
- image: {{ .Values.imageName }}
name: bigdl-ppml-finetuning-launcher
securityContext:
runAsUser: 1000
command: ['sh' , '-c', 'bash /ppml/bigdl-lora-finetuing-entrypoint.sh']
env:
- name: WORKER_ROLE
value: "launcher"
- name: WORLD_SIZE
value: "{{ .Values.trainerNum }}"
- name: MICRO_BATCH_SIZE
value: "{{ .Values.microBatchSize }}"
- name: MASTER_PORT
value: "42679"
- name: MASTER_ADDR
value: "bigdl-lora-finetuning-job-worker-0.bigdl-lora-finetuning-job-worker"
- name: DATA_SUB_PATH
value: "{{ .Values.dataSubPath }}"
- name: OMP_NUM_THREADS
value: "{{ .Values.ompNumThreads }}"
- name: LOCAL_POD_NAME
valueFrom:
fieldRef:
fieldPath: metadata.name
volumeMounts:
- name: nfs-storage
subPath: {{ .Values.modelSubPath }}
mountPath: /ppml/model
- name: nfs-storage
subPath: {{ .Values.dataSubPath }}
mountPath: "/ppml/data/{{ .Values.dataSubPath }}"
- name: nfs-storage
subPath: {{ .Values.outputSubPath }}
mountPath: "/ppml/output"
Worker:
replicas: {{ .Values.trainerNum }}
template:
spec:
containers:
- image: {{ .Values.imageName }}
name: bigdl-ppml-finetuning-worker
securityContext:
runAsUser: 1000
command: ['sh' , '-c', 'bash /ppml/bigdl-lora-finetuing-entrypoint.sh']
env:
- name: WORKER_ROLE
value: "trainer"
- name: WORLD_SIZE
value: "{{ .Values.trainerNum }}"
- name: MICRO_BATCH_SIZE
value: "{{ .Values.microBatchSize }}"
- name: MASTER_PORT
value: "42679"
- name: MASTER_ADDR
value: "bigdl-lora-finetuning-job-worker-0.bigdl-lora-finetuning-job-worker"
- name: LOCAL_POD_NAME
valueFrom:
fieldRef:
fieldPath: metadata.name
volumeMounts:
- name: nfs-storage
subPath: {{ .Values.modelSubPath }}
mountPath: /ppml/model
- name: nfs-storage
subPath: {{ .Values.dataSubPath }}
mountPath: "/ppml/data/{{ .Values.dataSubPath }}"
- name: nfs-storage
subPath: {{ .Values.outputSubPath }}
mountPath: "/ppml/output"
resources:
requests:
cpu: {{ .Values.cpuPerPod }}
volumes:
- name: nfs-storage
persistentVolumeClaim:
claimName: nfs-pvc

@@ -0,0 +1,4 @@
apiVersion: v1
kind: Namespace
metadata:
name: bigdl-lora-finetuning

@@ -0,0 +1,15 @@
apiVersion: v1
kind: PersistentVolume
metadata:
name: nfs-pv-bigdl-lora-finetuning
namespace: bigdl-lora-finetuning
spec:
capacity:
storage: 15Gi
accessModes:
- ReadWriteOnce
persistentVolumeReclaimPolicy: Retain
storageClassName: nfs
nfs:
path: {{ .Values.nfsPath }}
server: {{ .Values.nfsServerIp }}

@@ -0,0 +1,12 @@
kind: PersistentVolumeClaim
apiVersion: v1
metadata:
name: nfs-pvc
namespace: bigdl-lora-finetuning
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 10Gi
storageClassName: nfs

@@ -0,0 +1,11 @@
imageName: intelanalytics/bigdl-lora-finetuning:2.4.0-SNAPSHOT
trainerNum: 8
microBatchSize: 8
TEEMode: tdx # tdx or native
nfsServerIp: your_nfs_server_ip
nfsPath: a_nfs_shared_folder_path_on_the_server
dataSubPath: alpaca_data_cleaned_archive.json # a subpath of the data file under nfs directory
modelSubPath: llama-7b-hf # a subpath of the model file (dir) under nfs directory
outputSubPath: output # a subpath of the empty directory under the nfs directory to save finetuned model, for example, if you make an empty dir named 'output' at the nfsPath, the value should be 'output'
ompNumThreads: 14
cpuPerPod: 42