[LLM] code and document for distributed qlora (#9585)
* [LLM] code and document for distributed qlora
* doc
* refine for gradient checkpoint
* refine
* Update alpaca_qlora_finetuning_cpu.py
* add link in doc
Parent commit: d154b38bf9
This commit: 4e70e33934
8 changed files with 41 additions and 19 deletions
Docker image dependencies (bump transformers from 4.34.0 to a 4.35.0 wheel):

@@ -32,7 +32,7 @@ RUN mkdir -p /bigdl/data && mkdir -p /bigdl/model && \
     pip install intel_extension_for_pytorch==2.0.100 && \
     pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable && \
 # install huggingface dependencies
-    pip install datasets transformers==4.34.0 && \
+    pip install datasets https://files.pythonhosted.org/packages/9a/06/e4ec2a321e57c03b7e9345d709d554a52c33760e5015fdff0919d9459af0/transformers-4.35.0-py3-none-any.whl && \
     pip install fire peft==0.5.0 && \
     pip install accelerate==0.23.0 && \
 # install basic dependencies
MPI launcher entrypoint script (pass --gradient_checkpointing through when enabled):

@@ -9,6 +9,10 @@ then
   sed "s/:1/ /g" /etc/mpi/hostfile > /home/mpiuser/hostfile
   sleep 10 # wait for worker pods to be ready
   export ACCELERATE_USE_CPU=True
+  if [ "$ENABLE_GRADIENT_CHECKPOINT" = "true" ]
+  then
+    GRADIENT_CHECKPOINT_PARAM="--gradient_checkpointing"
+  fi
   mpirun \
     -n $WORLD_SIZE \
     -ppn 1 \
@@ -24,7 +28,8 @@ then
       --data_path "/bigdl/data" \
       --output_dir "/home/mpiuser/finetuned_model" \
       --batch_size 128 \
-      --micro_batch_size $MICRO_BATCH_SIZE > /home/mpiuser/launcher.log 2>&1
+      --micro_batch_size $MICRO_BATCH_SIZE \
+      $GRADIENT_CHECKPOINT_PARAM > /home/mpiuser/launcher.log 2>&1
   exit_status=$?
   if [ $exit_status -ne 0 ];
   then
README (Kubernetes fine-tuning title):

@@ -1,4 +1,4 @@
-## Run NF4&BF16-quantized QLoRA Finetuning on Kubernetes with OneCCL
+## Run Distributed QLoRA Fine-Tuning on Kubernetes with OneCCL
 
 
 
Helm job template (expose ENABLE_GRADIENT_CHECKPOINT to launcher and worker containers):

@@ -37,6 +37,8 @@ spec:
                value: "bigdl-qlora-finetuning-job-worker-0.bigdl-qlora-finetuning-job-worker"
              - name: DATA_SUB_PATH
                value: "{{ .Values.dataSubPath }}"
+             - name: ENABLE_GRADIENT_CHECKPOINT
+               value: "{{ .Values.enableGradientCheckpoint }}"
              - name: http_proxy
                value: "{{ .Values.httpProxy }}"
              - name: https_proxy
@@ -85,6 +87,8 @@ spec:
               value: "42679"
             - name: MASTER_ADDR
               value: "bigdl-qlora-finetuning-job-worker-0.bigdl-qlora-finetuning-job-worker"
+            - name: ENABLE_GRADIENT_CHECKPOINT
+              value: "{{ .Values.enableGradientCheckpoint }}"
             - name: http_proxy
               value: "{{ .Values.httpProxy }}"
             - name: https_proxy
Helm chart values:

@@ -1,6 +1,7 @@
 imageName: intelanalytics/bigdl-llm-finetune-qlora-cpu:2.5.0-SNAPSHOT
 trainerNum: 2
 microBatchSize: 8
+enableGradientCheckpoint: false # true will save more memory but increase latency
 nfsServerIp: your_nfs_server_ip
 nfsPath: a_nfs_shared_folder_path_on_the_server
 dataSubPath: alpaca_data_cleaned_archive.json # a subpath of the data file under nfs directory
README (supported fine-tuning modes):

@@ -7,7 +7,7 @@ This example demonstrates how to finetune a llama2-7b model using Big-LLM 4bit o
 1. Single node with single socket: [simple example](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/QLoRA-FineTuning#example-finetune-llama2-7b-using-qlora)
 or [alpaca example](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora)
 2. [Single node with multiple sockets](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/CPU/QLoRA-FineTuning/alpaca-qlora#guide-to-finetuning-qlora-on-one-node-with-multiple-sockets)
-3. multiple nodes with multiple sockets
+3. [multiple nodes with multiple sockets](https://github.com/intel-analytics/BigDL/blob/main/docker/llm/finetune/qlora/cpu/kubernetes/README.md)
 
 ## Example: Finetune llama2-7b using QLoRA
 
README (new section on docker and k8s fine-tuning):

@@ -127,3 +127,10 @@ from transformers import AutoTokenizer  # noqa: F402
 tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
 base_model = AutoModelForCausalLM.from_pretrained(base_model,trust_remote_code=True)
 ```
+
+
+### 4. Finetuning in docker and on multiple nodes (k8s)
+
+If you want to run multi-process fine-tuning, or prefer not to install the above dependencies manually, we provide a docker solution to quickly start a single-container fine-tuning job. Please refer to [here](https://github.com/intel-analytics/BigDL/tree/main/docker/llm/finetune/qlora/cpu/docker#fine-tune-llm-with-bigdl-llm-container).
+
+Moreover, for users with multiple CPU servers, e.g. Xeon series such as SPR and ICX, we provide a k8s distributed solution that lets machines and processor sockets collaborate with one click. Please refer to [here](https://github.com/intel-analytics/BigDL/blob/main/docker/llm/finetune/qlora/cpu/kubernetes/README.md) for how to run QLoRA on k8s.
alpaca_qlora_finetuning_cpu.py (remove module-level environment setup; it moves into train() below):

@@ -61,14 +61,6 @@ def get_int_from_env(env_keys, default):
             return val
     return default
 
-local_rank = get_int_from_env(["LOCAL_RANK","MPI_LOCALRANKID"], "0")
-world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
-port = get_int_from_env(["MASTER_PORT"], 29500)
-os.environ["LOCAL_RANK"] = str(local_rank)
-os.environ["WORLD_SIZE"] = str(world_size)
-os.environ["RANK"] = str(local_rank)
-os.environ["MASTER_PORT"] = str(port)
-
 def train(
     # model/data params
     base_model: str = "meta-llama/Llama-2-7b-hf",  # the only required argument, default to be "meta-llama/Llama-2-7b-hf"
alpaca_qlora_finetuning_cpu.py (log the gradient checkpointing setting):

@@ -134,6 +126,7 @@ def train(
             f"wandb_log_model: {wandb_log_model}\n"
             f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
             f"prompt template: {prompt_template_name}\n"
+            f"gradient_checkpointing: {gradient_checkpointing}\n"
         )
     assert (
         base_model
alpaca_qlora_finetuning_cpu.py (resolve world size for k8s vs. standalone runs):

@@ -143,7 +136,21 @@
     prompter = Prompter(prompt_template_name)
 
     device_map = "auto"
-    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    if os.environ.get("LOCAL_POD_NAME", "") != "": # K8S dist
+        pmi_world_size = int(os.environ.get('PMI_SIZE', -1))
+        if pmi_world_size > 0:
+            os.environ['WORLD_SIZE'] = str(pmi_world_size)
+        world_size = 1 if pmi_world_size == 0 else pmi_world_size
+    else: # Standalone (centralized or multi-process)
+        local_rank = get_int_from_env(["LOCAL_RANK","MPI_LOCALRANKID"], "0")
+        world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
+        port = get_int_from_env(["MASTER_PORT"], 29500)
+        os.environ["LOCAL_RANK"] = str(local_rank)
+        os.environ["WORLD_SIZE"] = str(world_size)
+        os.environ["RANK"] = str(local_rank)
+        os.environ["MASTER_PORT"] = str(port)
+
+    print(f"world_size: {world_size}")
     ddp = world_size != 1
     if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
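For clarity, here is a minimal, self-contained sketch of the world-size resolution this hunk introduces: when the script is launched inside the Kubernetes MPI job (detected via `LOCAL_POD_NAME`), the process count is taken from `PMI_SIZE`; otherwise the standalone path reads the usual `LOCAL_RANK`/`WORLD_SIZE`/`MASTER_PORT` variables and mirrors them back into the environment. The helper names and the simplified `get_int_from_env` below are illustrative only, not the exact code in the repository.

```python
import os


def get_int_from_env(env_keys, default):
    """Return the first env var in env_keys that is set, as an int; else the default."""
    for key in env_keys:
        val = os.environ.get(key)
        if val is not None:
            return int(val)
    return int(default)


def resolve_world_size():
    """Illustrative helper mirroring the logic added in this commit."""
    if os.environ.get("LOCAL_POD_NAME", "") != "":  # launched by the k8s MPI job
        pmi_world_size = int(os.environ.get("PMI_SIZE", -1))
        if pmi_world_size > 0:
            os.environ["WORLD_SIZE"] = str(pmi_world_size)
        world_size = 1 if pmi_world_size == 0 else pmi_world_size
    else:  # standalone, single- or multi-process
        local_rank = get_int_from_env(["LOCAL_RANK", "MPI_LOCALRANKID"], "0")
        world_size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE"], "1")
        port = get_int_from_env(["MASTER_PORT"], 29500)
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["RANK"] = str(local_rank)
        os.environ["MASTER_PORT"] = str(port)
    return world_size


if __name__ == "__main__":
    print(f"world_size: {resolve_world_size()}")  # DDP is enabled whenever this is not 1
```

Intel MPI's launcher typically exports `PMI_SIZE` for each rank, which is why the k8s path can rely on it rather than `WORLD_SIZE`.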
alpaca_qlora_finetuning_cpu.py (drop the commented-out device_map argument):

@@ -176,7 +183,6 @@
             load_in_low_bit="sym_int4", # not support "nf4"
             optimize_model=False,
             torch_dtype=torch.bfloat16,
-            # device_map=device_map,
             modules_to_not_convert=["lm_head"],
         )
     print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
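The keyword arguments above sit inside a model-loading call that this hunk does not show. A plausible reconstruction, assuming the script uses BigDL-LLM's Hugging-Face-compatible `AutoModelForCausalLM` loader (an assumption, since only the kwargs are visible here), looks like this:

```python
import torch
# Assumption: the loader comes from BigDL-LLM's transformers-compatible API;
# this hunk only shows the keyword arguments passed to it.
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",          # base_model default used elsewhere in this script
    load_in_low_bit="sym_int4",          # 4-bit weight-only quantization; "nf4" is not supported on this path
    optimize_model=False,
    torch_dtype=torch.bfloat16,
    modules_to_not_convert=["lm_head"],  # keep the LM head out of low-bit conversion
)
```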
alpaca_qlora_finetuning_cpu.py (non-reentrant gradient checkpointing; drop the manual PartialState workaround):

@@ -322,11 +328,9 @@
             report_to="wandb" if use_wandb else None,
             run_name=wandb_run_name if use_wandb else None,
             gradient_checkpointing=gradient_checkpointing,
+            gradient_checkpointing_kwargs={"use_reentrant": False} if gradient_checkpointing else None,
             ddp_backend="ccl" if ddp else None,
     )
-    if ddp:
-        from accelerate.state import PartialState
-        args.distributed_state = PartialState(cpu=True, backend=args.ddp_backend)
 
     trainer = transformers.Trainer(
         model=model,
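Taken together, the training-arguments side of the change looks roughly like the sketch below. `gradient_checkpointing_kwargs` only exists in newer transformers releases (presumably why the Dockerfile above moves from 4.34.0 to 4.35.0), and `use_reentrant=False` selects the non-reentrant checkpointing implementation, which behaves better under DDP when most parameters are frozen, as they are with LoRA adapters. The literal values here are placeholders, not the script's defaults.

```python
import transformers

gradient_checkpointing = True  # toggled by --gradient_checkpointing / ENABLE_GRADIENT_CHECKPOINT
ddp = True                     # world_size != 1

args = transformers.TrainingArguments(
    output_dir="./finetuned_model",   # placeholder
    per_device_train_batch_size=8,    # placeholder
    gradient_checkpointing=gradient_checkpointing,
    # Requires transformers >= 4.35; trades extra recomputation for lower activation memory.
    gradient_checkpointing_kwargs={"use_reentrant": False} if gradient_checkpointing else None,
    ddp_backend="ccl" if ddp else None,  # oneCCL backend for CPU distributed training
)
```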
alpaca_qlora_finetuning_cpu.py (add a trailing newline at end of file):

@@ -351,3 +355,4 @@ def train(
 
 if __name__ == "__main__":
     fire.Fire(train)
+