LLM: update DeepSpeed AutoTP example with GPU memory optimization (#9823)
This commit is contained in:
		
							parent
							
								
									5ba1dc38d4
								
							
						
					
					
						commit
						294fd32787
					
				
					 4 changed files with 88 additions and 16 deletions
				
			
		| 
						 | 
				
			
			@ -26,15 +26,28 @@ conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 | 
			
		|||
> **Important**: IPEX 2.1.10+xpu requires Intel® oneAPI Base Toolkit's version == 2024.0. Please make sure you have installed the correct version.
 | 
			
		||||
 | 
			
		||||
### 2. Run tensor parallel inference on multiple GPUs
 | 
			
		||||
Here, we provide example usages on different models and different hardwares. Please refer to the appropriate script based on your model and device:
 | 
			
		||||
Here, we separate inference process into two stages. First, convert to deepspeed model and apply bigdl-llm optimization on CPU. Then, utilize XPU as DeepSpeed accelerator to inference. In this way, a *X*B model saved in 16-bit will requires approximately 0.5*X* GB total GPU memory in the whole process. For example, if you select to use two GPUs, 0.25*X* GB memory is required per GPU.
 | 
			
		||||
 | 
			
		||||
#### Llama2 series
 | 
			
		||||
<details><summary>Show LLaMA2-70B example</summary>
 | 
			
		||||
Run LLaMA2-70B on four Intel Data Center GPU Max 1550
 | 
			
		||||
Please select the appropriate model size based on the capabilities of your machine.
 | 
			
		||||
 | 
			
		||||
We provide example usages on different models and different hardwares as following:
 | 
			
		||||
 | 
			
		||||
- Run LLaMA2-70B on one card of Intel Data Center GPU Max 1550
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
bash run_llama2_70b_pvc_1550_4_card.sh
 | 
			
		||||
bash run_llama2_70b_pvc_1550_1_card.sh
 | 
			
		||||
```
 | 
			
		||||
</details>
 | 
			
		||||
 | 
			
		||||
> **Note**:If you may want to select only part of GPUs on your machine, please change `ZE_AFFINITY_MASK` and `NUM_GPUS` to your prefer value.
 | 
			
		||||
> **Note**: You could change `ZE_AFFINITY_MASK` and `NUM_GPUS` according to your requirements.
 | 
			
		||||
 | 
			
		||||
- Run Vicuna-33B on two Intel Arc A770
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
bash run_vicuna_33b_arc_2_card.sh
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
> **Note**: You could change `NUM_GPUS` to the number of GPUs you have on your machine.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### Known Issue
 | 
			
		||||
- In our example scripts, tcmalloc is enabled through `export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}` which speed up inference, but this may raise `munmap_chunk(): invalid pointer` error after finishing inference.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -29,6 +29,10 @@ def get_int_from_env(env_keys, default):
 | 
			
		|||
 
 | 
			
		||||
local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
 | 
			
		||||
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
 | 
			
		||||
os.environ["RANK"] = str(local_rank)
 | 
			
		||||
os.environ["WORLD_SIZE"] = str(world_size)
 | 
			
		||||
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from bigdl.llm import optimize_model
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -38,6 +42,9 @@ import argparse
 | 
			
		|||
 | 
			
		||||
from transformers import AutoModelForCausalLM  # export AutoModelForCausalLM from transformers so that deepspeed use it
 | 
			
		||||
from transformers import LlamaTokenizer, AutoTokenizer
 | 
			
		||||
from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator
 | 
			
		||||
from deepspeed.accelerator import set_accelerator, get_accelerator
 | 
			
		||||
from intel_extension_for_deepspeed import XPU_Accelerator
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
 | 
			
		||||
| 
						 | 
				
			
			@ -52,7 +59,12 @@ if __name__ == '__main__':
 | 
			
		|||
    args = parser.parse_args()
 | 
			
		||||
    model_path = args.repo_id_or_model_path
 | 
			
		||||
 | 
			
		||||
    # First use CPU as accelerator
 | 
			
		||||
    # Convert to deepspeed model and apply bigdl-llm optimization on CPU to decrease GPU memory usage
 | 
			
		||||
    current_accel = CPU_Accelerator()
 | 
			
		||||
    set_accelerator(current_accel)
 | 
			
		||||
    model = AutoModelForCausalLM.from_pretrained(args.repo_id_or_model_path,
 | 
			
		||||
                                                 device_map={"": "cpu"},
 | 
			
		||||
                                                 low_cpu_mem_usage=True,
 | 
			
		||||
                                                 torch_dtype=torch.float16,
 | 
			
		||||
                                                 trust_remote_code=True,
 | 
			
		||||
| 
						 | 
				
			
			@ -65,14 +77,26 @@ if __name__ == '__main__':
 | 
			
		|||
        replace_method="auto",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # move model to cpu and use bigdl-llm `optimize_model` to convert the
 | 
			
		||||
    # model into optimized low bit format
 | 
			
		||||
    # convert the rest of the model into float16 to reduce allreduce traffic
 | 
			
		||||
    # Use bigdl-llm `optimize_model` to convert the model into optimized low bit format
 | 
			
		||||
    # Convert the rest of the model into float16 to reduce allreduce traffic
 | 
			
		||||
    model = optimize_model(model.module.to(f'cpu'), low_bit='sym_int4').to(torch.float16)
 | 
			
		||||
 | 
			
		||||
    # move model back to xpu
 | 
			
		||||
    # Next, use XPU as accelerator to speed up inference
 | 
			
		||||
    current_accel = XPU_Accelerator()
 | 
			
		||||
    set_accelerator(current_accel)
 | 
			
		||||
 | 
			
		||||
    # Move model back to xpu
 | 
			
		||||
    model = model.to(f'xpu:{local_rank}')
 | 
			
		||||
 | 
			
		||||
    # Modify backend related settings 
 | 
			
		||||
    if world_size > 1:
 | 
			
		||||
        get_accelerator().set_device(local_rank)
 | 
			
		||||
    dist_backend = get_accelerator().communication_backend_name()
 | 
			
		||||
    import deepspeed.comm.comm
 | 
			
		||||
    deepspeed.comm.comm.cdb = None
 | 
			
		||||
    from deepspeed.comm.comm import init_distributed
 | 
			
		||||
    init_distributed()
 | 
			
		||||
 | 
			
		||||
    print(model)
 | 
			
		||||
 | 
			
		||||
    # Load tokenizer
 | 
			
		||||
| 
						 | 
				
			
			@ -80,9 +104,7 @@ if __name__ == '__main__':
 | 
			
		|||
 | 
			
		||||
    # Generate predicted tokens
 | 
			
		||||
    with torch.inference_mode():
 | 
			
		||||
        # prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
 | 
			
		||||
        prompt = args.prompt
 | 
			
		||||
        # input_str = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n"
 | 
			
		||||
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')
 | 
			
		||||
        # ipex model needs a warmup, then inference time can be accurate
 | 
			
		||||
        output = model.generate(input_ids,
 | 
			
		||||
| 
						 | 
				
			
			@ -108,3 +130,5 @@ if __name__ == '__main__':
 | 
			
		|||
            print(prompt)
 | 
			
		||||
            print('-'*20, 'Output', '-'*20)
 | 
			
		||||
            print(output_str)
 | 
			
		||||
    deepspeed.comm.destroy_process_group()
 | 
			
		||||
    print("process group destroyed, exiting...")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,12 +14,14 @@
 | 
			
		|||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
export ZE_AFFINITY_MASK="0,1,2,3,4,5,6,7" # specify the used GPU
 | 
			
		||||
NUM_GPUS=8 # number of used GPU
 | 
			
		||||
export ZE_AFFINITY_MASK="0,1" # specify the used GPU
 | 
			
		||||
NUM_GPUS=2 # number of used GPU
 | 
			
		||||
export MASTER_ADDR=127.0.0.1
 | 
			
		||||
export FI_PROVIDER=tcp
 | 
			
		||||
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}
 | 
			
		||||
export CCL_ATL_TRANSPORT=ofi
 | 
			
		||||
export CCL_ZE_IPC_EXCHANGE=sockets
 | 
			
		||||
 | 
			
		||||
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}
 | 
			
		||||
basekit_root=/opt/intel/oneapi
 | 
			
		||||
source $basekit_root/setvars.sh --force
 | 
			
		||||
source $basekit_root/ccl/latest/env/vars.sh --force
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,33 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
export MASTER_ADDR=127.0.0.1
 | 
			
		||||
export FI_PROVIDER=tcp
 | 
			
		||||
export CCL_ATL_TRANSPORT=ofi
 | 
			
		||||
export CCL_ZE_IPC_EXCHANGE=sockets
 | 
			
		||||
 | 
			
		||||
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}
 | 
			
		||||
basekit_root=/opt/intel/oneapi
 | 
			
		||||
source $basekit_root/setvars.sh --force
 | 
			
		||||
source $basekit_root/ccl/latest/env/vars.sh --force
 | 
			
		||||
 | 
			
		||||
NUM_GPUS=2 # number of used GPU
 | 
			
		||||
export USE_XETLA=OFF
 | 
			
		||||
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
 | 
			
		||||
export TORCH_LLM_ALLREDUCE=0 # Different from PVC
 | 
			
		||||
 | 
			
		||||
mpirun -np $NUM_GPUS --prepend-rank \
 | 
			
		||||
    python deepspeed_autotp.py --repo-id-or-model-path 'lmsys/vicuna-33b-v1.3'
 | 
			
		||||
		Loading…
	
		Reference in a new issue