[LLM] code and document for distributed qlora (#9585)
* [LLM] code and document for distributed qlora
* doc
* refine for gradient checkpoint
* refine
* Update alpaca_qlora_finetuning_cpu.py
* Update alpaca_qlora_finetuning_cpu.py
* Update alpaca_qlora_finetuning_cpu.py
* add link in doc
parent d154b38bf9
commit 4e70e33934
8 changed files with 41 additions and 19 deletions
@@ -32,7 +32,7 @@ RUN mkdir -p /bigdl/data && mkdir -p /bigdl/model && \
    pip install intel_extension_for_pytorch==2.0.100 && \
    pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable && \
    # install huggingface dependencies
    pip install datasets transformers==4.34.0 && \
    pip install datasets https://files.pythonhosted.org/packages/9a/06/e4ec2a321e57c03b7e9345d709d554a52c33760e5015fdff0919d9459af0/transformers-4.35.0-py3-none-any.whl && \
    pip install fire peft==0.5.0 && \
    pip install accelerate==0.23.0 && \
    # install basic dependencies
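The image now installs transformers from a pinned 4.35.0 wheel alongside peft 0.5.0 and accelerate 0.23.0. As an illustration only (not part of this commit), a quick in-container check can confirm the pins actually resolved:

```python
# Illustrative sanity check, not part of this commit: verify the pinned
# packages inside the freshly built container.
import importlib.metadata as md

expected = {
    "transformers": "4.35.0",   # installed from the wheel URL above
    "peft": "0.5.0",
    "accelerate": "0.23.0",
}

for pkg, want in expected.items():
    got = md.version(pkg)
    assert got == want, f"{pkg}: expected {want}, found {got}"
print("all pinned versions match")
```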
@@ -9,6 +9,10 @@ then
sed "s/:1/ /g" /etc/mpi/hostfile > /home/mpiuser/hostfile
sleep 10 # wait for worker pods to be ready
export ACCELERATE_USE_CPU=True
if [ "$ENABLE_GRADIENT_CHECKPOINT" = "true" ]
then
  GRADIENT_CHECKPOINT_PARAM="--gradient_checkpointing"
fi
mpirun \
  -n $WORLD_SIZE \
  -ppn 1 \
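The launcher turns the `ENABLE_GRADIENT_CHECKPOINT` environment variable into an optional CLI flag: when it is not the string `true`, `GRADIENT_CHECKPOINT_PARAM` stays unset and expands to nothing. A sketch of the same pattern in Python (the script name is assumed for illustration):

```python
# Sketch of the launcher logic above in Python; script name is assumed.
import os
import shlex

world_size = int(os.environ.get("WORLD_SIZE", "1"))
cmd = ["mpirun", "-n", str(world_size), "-ppn", "1",
       "python", "alpaca_qlora_finetuning_cpu.py"]
# Only append the flag when the env var is exactly the string "true",
# mirroring the shell comparison above.
if os.environ.get("ENABLE_GRADIENT_CHECKPOINT") == "true":
    cmd.append("--gradient_checkpointing")
print(shlex.join(cmd))
```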
@@ -24,7 +28,8 @@ then
  --data_path "/bigdl/data" \
  --output_dir "/home/mpiuser/finetuned_model" \
  --batch_size 128 \
  --micro_batch_size $MICRO_BATCH_SIZE > /home/mpiuser/launcher.log 2>&1
  --micro_batch_size $MICRO_BATCH_SIZE \
  $GRADIENT_CHECKPOINT_PARAM > /home/mpiuser/launcher.log 2>&1
exit_status=$?
if [ $exit_status -ne 0 ];
then
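With `--batch_size 128` fixed and `--micro_batch_size` coming from the chart, the trainer derives the number of gradient-accumulation steps; in the alpaca-lora recipe this script follows, the global batch is further split across data-parallel ranks. A worked example under those assumptions:

```python
# Worked example (assumes the alpaca-lora accumulation convention).
batch_size = 128        # --batch_size: examples per global optimizer step
micro_batch_size = 8    # --micro_batch_size (microBatchSize in values.yaml)
world_size = 2          # trainerNum: one rank per pod via `mpirun -ppn 1`

gradient_accumulation_steps = batch_size // micro_batch_size  # 16 micro-batches
gradient_accumulation_steps //= world_size                    # split across ranks
print(gradient_accumulation_steps)  # -> 8 accumulation steps per rank
```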
@@ -1,4 +1,4 @@
## Run NF4&BF16-quantized QLoRA Finetuning on Kubernetes with OneCCL
## Run Distributed QLoRA Fine-Tuning on Kubernetes with OneCCL

![image](https://llm-assets.readthedocs.io/en/latest/_images/bigdl_llm_qlora_k8s_arch.png)
@@ -37,6 +37,8 @@ spec:
  value: "bigdl-qlora-finetuning-job-worker-0.bigdl-qlora-finetuning-job-worker"
- name: DATA_SUB_PATH
  value: "{{ .Values.dataSubPath }}"
- name: ENABLE_GRADIENT_CHECKPOINT
  value: "{{ .Values.enableGradientCheckpoint }}"
- name: http_proxy
  value: "{{ .Values.httpProxy }}"
- name: https_proxy
@@ -85,6 +87,8 @@ spec:
  value: "42679"
- name: MASTER_ADDR
  value: "bigdl-qlora-finetuning-job-worker-0.bigdl-qlora-finetuning-job-worker"
- name: ENABLE_GRADIENT_CHECKPOINT
  value: "{{ .Values.enableGradientCheckpoint }}"
- name: http_proxy
  value: "{{ .Values.httpProxy }}"
- name: https_proxy
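Both the launcher and worker pods now receive `ENABLE_GRADIENT_CHECKPOINT` from `.Values.enableGradientCheckpoint`. Since the entrypoint compares the raw string, here is a small sketch of the effective semantics:

```python
# Illustrative only: how the entrypoint interprets the env var plumbed
# through the Helm templates above.
import os

raw = os.environ.get("ENABLE_GRADIENT_CHECKPOINT", "false")
# The shell test is [ "$ENABLE_GRADIENT_CHECKPOINT" = "true" ], so only the
# exact lowercase string "true" enables the flag; "True", "1", or "yes" do not.
use_gradient_checkpointing = (raw == "true")
print(f"gradient checkpointing: {use_gradient_checkpointing}")
```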
@@ -1,6 +1,7 @@
imageName: intelanalytics/bigdl-llm-finetune-qlora-cpu:2.5.0-SNAPSHOT
trainerNum: 2
microBatchSize: 8
enableGradientCheckpoint: false # true will save more memory but increase latency
nfsServerIp: your_nfs_server_ip
nfsPath: a_nfs_shared_folder_path_on_the_server
dataSubPath: alpaca_data_cleaned_archive.json # the subpath of the data file under the NFS directory
@@ -7,7 +7,7 @@ This example demonstrates how to finetune a llama2-7b model using BigDL-LLM 4bit o
1. Single node with single socket: [simple example](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/QLoRA-FineTuning#example-finetune-llama2-7b-using-qlora)
or [alpaca example](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora)
2. [Single node with multiple sockets](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora#guide-to-finetuning-qlora-on-one-node-with-multiple-sockets)
3. multiple nodes with multiple sockets
3. [multiple nodes with multiple sockets](https://github.com/intel-analytics/BigDL/blob/main/docker/llm/finetune/qlora/cpu/kubernetes/README.md)

## Example: Finetune llama2-7b using QLoRA
@@ -126,4 +126,11 @@ need to modify the [tokenization_baichuan.py](https://huggingface.co/baichuan-in
from transformers import AutoTokenizer  # noqa: F402
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(base_model, trust_remote_code=True)
```

### 4. Finetuning in docker and multiple nodes (k8s)

If you want to run multi-process fine-tuning, or do not want to install the above dependencies manually, we provide a Docker solution to quickly start fine-tuning in a single container. Please refer to [here](https://github.com/intel-analytics/BigDL/tree/main/docker/llm/finetune/qlora/cpu/docker#fine-tune-llm-with-bigdl-llm-container).

Moreover, for users with multiple CPU servers, e.g. Xeon series like SPR and ICX, we provide a k8s distributed solution that lets machines and processor sockets collaborate with one click. Please refer to [here](https://github.com/intel-analytics/BigDL/blob/main/docker/llm/finetune/qlora/cpu/kubernetes/README.md) for how to run QLoRA on k8s.
@@ -61,14 +61,6 @@ def get_int_from_env(env_keys, default):
            return val
    return default

local_rank = get_int_from_env(["LOCAL_RANK","MPI_LOCALRANKID"], "0")
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
port = get_int_from_env(["MASTER_PORT"], 29500)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(local_rank)
os.environ["MASTER_PORT"] = str(port)

def train(
    # model/data params
    base_model: str = "meta-llama/Llama-2-7b-hf",  # the only required argument; defaults to "meta-llama/Llama-2-7b-hf"
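This hunk deletes the module-level environment setup; the same logic reappears inside `train()` (the `+136,21` hunk below), so the Kubernetes path can branch on `LOCAL_POD_NAME` before anything is overwritten. For context, `get_int_from_env` scans candidate environment variables and falls back to a default, consistent with the `return val` / `return default` lines kept above; a sketch of such a helper:

```python
import os

def get_int_from_env(env_keys, default):
    """Sketch consistent with the usage above: return the first
    non-negative integer found among env_keys, else the default."""
    for key in env_keys:
        val = int(os.environ.get(key, -1))
        if val >= 0:
            return val
    return default
```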
@@ -134,6 +126,7 @@ def train(
    f"wandb_log_model: {wandb_log_model}\n"
    f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
    f"prompt template: {prompt_template_name}\n"
    f"gradient_checkpointing: {gradient_checkpointing}\n"
)
assert (
    base_model
@@ -143,7 +136,21 @@ def train(
    prompter = Prompter(prompt_template_name)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if os.environ.get("LOCAL_POD_NAME", "") != "":  # K8S dist
        pmi_world_size = int(os.environ.get('PMI_SIZE', -1))
        if pmi_world_size > 0:
            os.environ['WORLD_SIZE'] = str(pmi_world_size)
        world_size = 1 if pmi_world_size == 0 else pmi_world_size
    else:  # Standalone (centralized or multi-process)
        local_rank = get_int_from_env(["LOCAL_RANK","MPI_LOCALRANKID"], "0")
        world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
        port = get_int_from_env(["MASTER_PORT"], 29500)
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["RANK"] = str(local_rank)
        os.environ["MASTER_PORT"] = str(port)

    print(f"world_size: {world_size}")
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
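Under Kubernetes, Intel MPI injects `PMI_SIZE` for each rank, and the new branch adopts it as `WORLD_SIZE`; `LOCAL_POD_NAME` is presumably set by the pod spec. A toy walk-through with illustrative values:

```python
# Toy walk-through of the K8s branch above; env values are illustrative.
import os

os.environ["LOCAL_POD_NAME"] = "bigdl-qlora-finetuning-job-worker-0"
os.environ["PMI_SIZE"] = "2"   # injected by Intel MPI under mpirun

world_size = int(os.environ.get("WORLD_SIZE", 1))
if os.environ.get("LOCAL_POD_NAME", "") != "":
    pmi_world_size = int(os.environ.get("PMI_SIZE", -1))
    if pmi_world_size > 0:
        os.environ["WORLD_SIZE"] = str(pmi_world_size)
    world_size = 1 if pmi_world_size == 0 else pmi_world_size

print(world_size)  # -> 2, so ddp is True and a per-rank device_map is used
```

Note that if `PMI_SIZE` were unset entirely, `world_size` would land at -1; the chart always launches the script under `mpirun`, so that case is not expected to occur in practice.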
@@ -176,7 +183,6 @@
        load_in_low_bit="sym_int4",  # "nf4" is not supported
        optimize_model=False,
        torch_dtype=torch.bfloat16,
        # device_map=device_map,
        modules_to_not_convert=["lm_head"],
    )
    print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
@@ -190,7 +196,7 @@
        0  # unk. we want this to be different from the eos token
    )
    tokenizer.padding_side = "left"  # Allow batched inference

    print(model)

    def tokenize(prompt, add_eos_token=True):
@@ -322,11 +328,9 @@
        report_to="wandb" if use_wandb else None,
        run_name=wandb_run_name if use_wandb else None,
        gradient_checkpointing=gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False} if gradient_checkpointing else None,
        ddp_backend="ccl" if ddp else None,
    )
    if ddp:
        from accelerate.state import PartialState
        args.distributed_state = PartialState(cpu=True, backend=args.ddp_backend)

    trainer = transformers.Trainer(
        model=model,
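`gradient_checkpointing_kwargs` only exists in transformers >= 4.35.0, which is presumably why the Dockerfile above moves from 4.34.0 to the 4.35.0 wheel; `use_reentrant=False` selects the non-reentrant checkpointing implementation, which composes better with DDP gradient synchronization. A minimal sketch of the new pieces in isolation, with placeholder values, assuming the MPI/CCL environment variables were exported earlier as this script does:

```python
# Minimal sketch with placeholder values; assumes RANK/WORLD_SIZE/MASTER_ADDR
# etc. were set up earlier, as in the script above.
import transformers
from accelerate.state import PartialState

args = transformers.TrainingArguments(
    output_dir="/tmp/qlora-out",  # placeholder path
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    ddp_backend="ccl",            # oneCCL, provided by oneccl_bind_pt
)
# As in the diff: pre-build the distributed state on CPU with the CCL
# backend so the Trainer joins the oneCCL process group rather than
# probing for CUDA devices.
args.distributed_state = PartialState(cpu=True, backend=args.ddp_backend)
```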
@@ -351,3 +355,4 @@ def train(

if __name__ == "__main__":
    fire.Fire(train)