[LLM] code and document for distributed qlora (#9585)

* [LLM] code and document for distributed qlora

* doc

* refine for gradient checkpoint

* refine

* Update alpaca_qlora_finetuning_cpu.py

* Update alpaca_qlora_finetuning_cpu.py

* Update alpaca_qlora_finetuning_cpu.py

* add link in doc
Heyang Sun 2023-12-06 09:23:17 +08:00 committed by GitHub
parent d154b38bf9
commit 4e70e33934
8 changed files with 41 additions and 19 deletions

View file

@@ -32,7 +32,7 @@ RUN mkdir -p /bigdl/data && mkdir -p /bigdl/model && \
pip install intel_extension_for_pytorch==2.0.100 && \
pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable && \
# install huggingface dependencies
- pip install datasets transformers==4.34.0 && \
+ pip install datasets https://files.pythonhosted.org/packages/9a/06/e4ec2a321e57c03b7e9345d709d554a52c33760e5015fdff0919d9459af0/transformers-4.35.0-py3-none-any.whl && \
pip install fire peft==0.5.0 && \
pip install accelerate==0.23.0 && \
# install basic dependencies
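Note: the pin is bumped from transformers 4.34.0 to the 4.35.0 wheel because `TrainingArguments.gradient_checkpointing_kwargs`, used later in this diff, first shipped in transformers 4.35 and is not accepted by 4.34. A minimal sanity-check sketch, assuming the container's Python environment:

```python
# Sanity-check sketch: confirm the installed transformers accepts
# `gradient_checkpointing_kwargs` (introduced in transformers 4.35).
import inspect

import transformers

params = inspect.signature(transformers.TrainingArguments.__init__).parameters
assert "gradient_checkpointing_kwargs" in params, (
    f"transformers {transformers.__version__} is too old; 4.35.0+ is required"
)
```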

View file

@@ -9,6 +9,10 @@ then
sed "s/:1/ /g" /etc/mpi/hostfile > /home/mpiuser/hostfile
sleep 10 # wait for worker pods to be ready
export ACCELERATE_USE_CPU=True
if [ "$ENABLE_GRADIENT_CHECKPOINT" = "true" ]
then
GRADIENT_CHECKPOINT_PARAM="--gradient_checkpointing"
fi
mpirun \
-n $WORLD_SIZE \
-ppn 1 \
@@ -24,7 +28,8 @@ then
--data_path "/bigdl/data" \
--output_dir "/home/mpiuser/finetuned_model" \
--batch_size 128 \
- --micro_batch_size $MICRO_BATCH_SIZE > /home/mpiuser/launcher.log 2>&1
+ --micro_batch_size $MICRO_BATCH_SIZE \
+ $GRADIENT_CHECKPOINT_PARAM > /home/mpiuser/launcher.log 2>&1
exit_status=$?
if [ $exit_status -ne 0 ];
then
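Note on this launcher change: the bare `--gradient_checkpointing` flag works because the training script hands its CLI to `fire.Fire(train)` (see the last hunk below), and fire maps a bare boolean flag to `True`; when the env var is unset, the empty `$GRADIENT_CHECKPOINT_PARAM` expands to nothing and the parameter keeps its default. A minimal illustrative sketch (this `train` stub is hypothetical, not the real script):

```python
# demo.py -- minimal sketch of fire's bare-flag handling.
# `python demo.py --micro_batch_size 8 --gradient_checkpointing`
# prints gradient_checkpointing=True; omit the flag and it stays False.
import fire


def train(batch_size: int = 128,
          micro_batch_size: int = 8,
          gradient_checkpointing: bool = False):
    print(f"gradient_checkpointing={gradient_checkpointing}")


if __name__ == "__main__":
    fire.Fire(train)
```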

View file

@ -1,4 +1,4 @@
## Run NF4&BF16-quantized QLoRA Finetuning on Kubernetes with OneCCL
## Run Distributed QLoRA Fine-Tuning on Kubernetes with OneCCL
![image](https://github.com/intel-analytics/BigDL/assets/60865256/825f47d9-c864-4f39-a331-adb1e3cb528e)

View file

@@ -37,6 +37,8 @@ spec:
value: "bigdl-qlora-finetuning-job-worker-0.bigdl-qlora-finetuning-job-worker"
- name: DATA_SUB_PATH
value: "{{ .Values.dataSubPath }}"
+ - name: ENABLE_GRADIENT_CHECKPOINT
+   value: "{{ .Values.enableGradientCheckpoint }}"
- name: http_proxy
value: "{{ .Values.httpProxy }}"
- name: https_proxy
@@ -85,6 +87,8 @@ spec:
value: "42679"
- name: MASTER_ADDR
value: "bigdl-qlora-finetuning-job-worker-0.bigdl-qlora-finetuning-job-worker"
+ - name: ENABLE_GRADIENT_CHECKPOINT
+   value: "{{ .Values.enableGradientCheckpoint }}"
- name: http_proxy
value: "{{ .Values.httpProxy }}"
- name: https_proxy
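Both the launcher and the worker containers receive the Helm value as the plain env var `ENABLE_GRADIENT_CHECKPOINT`, which the entrypoint script tests with `[ "$ENABLE_GRADIENT_CHECKPOINT" = "true" ]`. The equivalent check from Python, as an illustrative sketch:

```python
# Illustrative sketch: reading the Helm-injected flag inside a container,
# mirroring the shell test in the entrypoint script above.
import os

enable_gc = os.environ.get("ENABLE_GRADIENT_CHECKPOINT", "false") == "true"
print(f"gradient checkpointing requested: {enable_gc}")
```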

View file

@@ -1,6 +1,7 @@
imageName: intelanalytics/bigdl-llm-finetune-qlora-cpu:2.5.0-SNAPSHOT
trainerNum: 2
microBatchSize: 8
+ enableGradientCheckpoint: false # true will save more memory but increase latency
nfsServerIp: your_nfs_server_ip
nfsPath: a_nfs_shared_folder_path_on_the_server
dataSubPath: alpaca_data_cleaned_archive.json # a subpath of the data file under nfs directory
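The comment on `enableGradientCheckpoint` is the standard gradient-checkpointing tradeoff: activations of checkpointed blocks are dropped during the forward pass and recomputed during backward, cutting peak memory at the cost of extra compute (hence the added latency). A self-contained PyTorch sketch of the mechanism:

```python
# Sketch of the memory/latency tradeoff behind enableGradientCheckpoint:
# the checkpointed block stores no intermediate activations and re-runs
# its forward pass during backward to rebuild them.
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.GELU(),
    torch.nn.Linear(1024, 1024),
)
x = torch.randn(4, 1024, requires_grad=True)

y = checkpoint(block, x, use_reentrant=False)  # saves memory, adds recompute
y.sum().backward()
print(x.grad.shape)  # gradients flow as usual: torch.Size([4, 1024])
```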

View file

@@ -7,7 +7,7 @@ This example demonstrates how to finetune a llama2-7b model using BigDL-LLM 4bit o
1. Single node with single socket: [simple example](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/QLoRA-FineTuning#example-finetune-llama2-7b-using-qlora)
or [alpaca example](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora)
2. [Single node with multiple sockets](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora#guide-to-finetuning-qlora-on-one-node-with-multiple-sockets)
- 3. multiple nodes with multiple sockets
+ 3. [multiple nodes with multiple sockets](https://github.com/intel-analytics/BigDL/blob/main/docker/llm/finetune/qlora/cpu/kubernetes/README.md)
## Example: Finetune llama2-7b using QLoRA

View file

@@ -126,4 +126,11 @@ need to modify the [tokenization_baichuan.py](https://huggingface.co/baichuan-in
from transformers import AutoTokenizer # noqa: F402
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(base_model,trust_remote_code=True)
- ```
+ ```
+ ### 4. Finetuning in Docker and on multiple nodes (k8s)
+ If you want to run multi-process fine-tuning, or prefer not to install the above dependencies manually, we provide a Docker solution that quickly starts fine-tuning in a single container. Please refer to [here](https://github.com/intel-analytics/BigDL/tree/main/docker/llm/finetune/qlora/cpu/docker#fine-tune-llm-with-bigdl-llm-container).
+ Moreover, for users with multiple CPU servers (e.g., Xeon series such as SPR and ICX), we provide a distributed k8s solution that lets machines and processor sockets collaborate with one click. Please refer to [here](https://github.com/intel-analytics/BigDL/blob/main/docker/llm/finetune/qlora/cpu/kubernetes/README.md) for how to run QLoRA on k8s.

View file

@@ -61,14 +61,6 @@ def get_int_from_env(env_keys, default):
return val
return default
- local_rank = get_int_from_env(["LOCAL_RANK","MPI_LOCALRANKID"], "0")
- world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
- port = get_int_from_env(["MASTER_PORT"], 29500)
- os.environ["LOCAL_RANK"] = str(local_rank)
- os.environ["WORLD_SIZE"] = str(world_size)
- os.environ["RANK"] = str(local_rank)
- os.environ["MASTER_PORT"] = str(port)
def train(
# model/data params
base_model: str = "meta-llama/Llama-2-7b-hf", # the only required argument, default to be "meta-llama/Llama-2-7b-hf"
@@ -134,6 +126,7 @@
f"wandb_log_model: {wandb_log_model}\n"
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
f"prompt template: {prompt_template_name}\n"
f"gradient_checkpointing: {gradient_checkpointing}\n"
)
assert (
base_model
@@ -143,7 +136,21 @@
prompter = Prompter(prompt_template_name)
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
if os.environ.get("LOCAL_POD_NAME", "") != "": # K8S dist
pmi_world_size = int(os.environ.get('PMI_SIZE', -1))
if pmi_world_size > 0:
os.environ['WORLD_SIZE'] = str(pmi_world_size)
world_size = 1 if pmi_world_size == 0 else pmi_world_size
else: # Standalone (centralized or multi-process)
local_rank = get_int_from_env(["LOCAL_RANK","MPI_LOCALRANKID"], "0")
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
port = get_int_from_env(["MASTER_PORT"], 29500)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(local_rank)
os.environ["MASTER_PORT"] = str(port)
print(f"world_size: {world_size}")
ddp = world_size != 1
if ddp:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
@@ -176,7 +183,6 @@
load_in_low_bit="sym_int4", # not support "nf4"
optimize_model=False,
torch_dtype=torch.bfloat16,
- # device_map=device_map,
modules_to_not_convert=["lm_head"],
)
print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
@@ -190,7 +196,7 @@
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference
+ print(model)
def tokenize(prompt, add_eos_token=True):
@@ -322,11 +328,9 @@
report_to="wandb" if use_wandb else None,
run_name=wandb_run_name if use_wandb else None,
gradient_checkpointing=gradient_checkpointing,
+ gradient_checkpointing_kwargs={"use_reentrant": False} if gradient_checkpointing else None,
ddp_backend="ccl" if ddp else None,
)
- if ddp:
-     from accelerate.state import PartialState
-     args.distributed_state = PartialState(cpu=True, backend=args.ddp_backend)
trainer = transformers.Trainer(
model=model,
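The two changes above are the crux of the checkpointing support: `gradient_checkpointing_kwargs={"use_reentrant": False}` selects non-reentrant checkpointing, which tolerates inputs that do not require grad (the frozen 4-bit base weights under QLoRA), while the reentrant variant does not; and the manual `PartialState` wiring is dropped, presumably because the transformers 4.35 / accelerate stack now initializes the CPU distributed state itself from `ddp_backend="ccl"` and `ACCELERATE_USE_CPU=True`. A minimal, hedged sketch of the resulting arguments (illustrative values only; the real script configures much more):

```python
# Hedged sketch of the key TrainingArguments after this change
# (illustrative values, not the script's full configuration).
import transformers

args = transformers.TrainingArguments(
    output_dir="./finetuned_model",   # illustrative path
    per_device_train_batch_size=8,
    gradient_accumulation_steps=16,
    bf16=True,
    use_cpu=True,                      # CPU fine-tuning, as in this example
    gradient_checkpointing=True,
    # Non-reentrant checkpointing works with frozen (requires_grad=False)
    # base weights; the reentrant variant needs grad-requiring inputs.
    gradient_checkpointing_kwargs={"use_reentrant": False},
    ddp_backend="ccl",                 # oneCCL backend for distributed CPU runs
)
```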
@@ -351,3 +355,4 @@
if __name__ == "__main__":
fire.Fire(train)