From 294fd3278750bb80f8bc3fbbf83ff3b09bcd4ab3 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Tue, 9 Jan 2024 09:22:49 +0800
Subject: [PATCH] LLM: update DeepSpeed AutoTP example with GPU memory optimization (#9823)

---
 .../example/GPU/Deepspeed-AutoTP/README.md    | 27 ++++++++++----
 .../GPU/Deepspeed-AutoTP/deepspeed_autotp.py  | 36 +++++++++++++++----
 ...d.sh => run_llama2_70b_pvc_1550_1_card.sh} |  8 +++--
 .../run_vicuna_33b_arc_2_card.sh              | 33 +++++++++++++++++
 4 files changed, 88 insertions(+), 16 deletions(-)
 rename python/llm/example/GPU/Deepspeed-AutoTP/{run_llama2_70b_pvc_1550_4_card.sh => run_llama2_70b_pvc_1550_1_card.sh} (88%)
 create mode 100644 python/llm/example/GPU/Deepspeed-AutoTP/run_vicuna_33b_arc_2_card.sh

diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/README.md b/python/llm/example/GPU/Deepspeed-AutoTP/README.md
index fe2804e5..4da831f7 100644
--- a/python/llm/example/GPU/Deepspeed-AutoTP/README.md
+++ b/python/llm/example/GPU/Deepspeed-AutoTP/README.md
@@ -26,15 +26,28 @@ conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 > **Important**: IPEX 2.1.10+xpu requires Intel® oneAPI Base Toolkit's version == 2024.0. Please make sure you have installed the correct version.
 
 ### 2. Run tensor parallel inference on multiple GPUs
-Here, we provide example usages on different models and different hardwares. Please refer to the appropriate script based on your model and device:
+Here, we separate the inference process into two stages. First, the model is converted to a DeepSpeed model and bigdl-llm optimization is applied on CPU. Then, XPU is used as the DeepSpeed accelerator for inference. In this way, an *X*B model saved in 16-bit precision requires approximately 0.5*X* GB of GPU memory in total over the whole process. For example, if you choose to use two GPUs, about 0.25*X* GB of memory is required per GPU.
 
-#### Llama2 series
-<details>
-<summary>Show LLaMA2-70B example</summary>
-Run LLaMA2-70B on four Intel Data Center GPU Max 1550
+Please select the appropriate model size based on the capabilities of your machine.
+
+We provide example usage for different models and hardware as follows:
+
+- Run LLaMA2-70B on one card of Intel Data Center GPU Max 1550
 
 ```
-bash run_llama2_70b_pvc_1550_4_card.sh
+bash run_llama2_70b_pvc_1550_1_card.sh
 ```
-</details>
-> **Note**:If you may want to select only part of GPUs on your machine, please change `ZE_AFFINITY_MASK` and `NUM_GPUS` to your prefer value.
+> **Note**: You can change `ZE_AFFINITY_MASK` and `NUM_GPUS` according to your requirements.
+
+- Run Vicuna-33B on two Intel Arc A770 GPUs
+
+```
+bash run_vicuna_33b_arc_2_card.sh
+```
+
+> **Note**: You can change `NUM_GPUS` to the number of GPUs you have on your machine.
+
+
+### Known Issue
+- In our example scripts, tcmalloc is enabled via `export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}` to speed up inference, but this may raise a `munmap_chunk(): invalid pointer` error after inference finishes.
diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py
index 381217c4..13ee9d65 100644
--- a/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py
+++ b/python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py
@@ -29,6 +29,10 @@ def get_int_from_env(env_keys, default):
 local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
 world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
+os.environ["RANK"] = str(local_rank)
+os.environ["WORLD_SIZE"] = str(world_size)
+os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
+
 from bigdl.llm import optimize_model
 
@@ -38,6 +42,9 @@ import argparse
 from transformers import AutoModelForCausalLM # export AutoModelForCausalLM from transformers so that deepspeed use it
 from transformers import LlamaTokenizer, AutoTokenizer
+from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator
+from deepspeed.accelerator import set_accelerator, get_accelerator
+from intel_extension_for_deepspeed import XPU_Accelerator
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
@@ -52,7 +59,12 @@ if __name__ == '__main__':
     args = parser.parse_args()
     model_path = args.repo_id_or_model_path
 
+    # First, use CPU as the accelerator:
+    # convert to a DeepSpeed model and apply bigdl-llm optimization on CPU to decrease GPU memory usage
+    current_accel = CPU_Accelerator()
+    set_accelerator(current_accel)
     model = AutoModelForCausalLM.from_pretrained(args.repo_id_or_model_path,
+                                                 device_map={"": "cpu"},
                                                  low_cpu_mem_usage=True,
                                                  torch_dtype=torch.float16,
                                                  trust_remote_code=True,
@@ -65,14 +77,26 @@ if __name__ == '__main__':
                                         replace_method="auto",
     )
 
-    # move model to cpu and use bigdl-llm `optimize_model` to convert the
-    # model into optimized low bit format
-    # convert the rest of the model into float16 to reduce allreduce traffic
+    # Use bigdl-llm `optimize_model` to convert the model into optimized low bit format
+    # Convert the rest of the model into float16 to reduce allreduce traffic
     model = optimize_model(model.module.to(f'cpu'), low_bit='sym_int4').to(torch.float16)
 
-    # move model back to xpu
+    # Next, use XPU as the accelerator to speed up inference
+    current_accel = XPU_Accelerator()
+    set_accelerator(current_accel)
+
+    # Move model back to xpu
     model = model.to(f'xpu:{local_rank}')
 
+    # Modify backend-related settings
+    if world_size > 1:
+        get_accelerator().set_device(local_rank)
+        dist_backend = get_accelerator().communication_backend_name()
+        import deepspeed.comm.comm
+        deepspeed.comm.comm.cdb = None
+        from deepspeed.comm.comm import init_distributed
+        init_distributed()
+
     print(model)
 
     # Load tokenizer
@@ -80,9 +104,7 @@ if __name__ == '__main__':
 
     # Generate predicted tokens
     with torch.inference_mode():
-        # prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
         prompt = args.prompt
-        # input_str = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n"
         input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')
         # ipex model needs a warmup, then inference time can be accurate
         output = model.generate(input_ids,
@@ -108,3 +130,5 @@ if __name__ == '__main__':
             print(prompt)
             print('-'*20, 'Output', '-'*20)
             print(output_str)
+    deepspeed.comm.destroy_process_group()
+    print("process group destroyed, exiting...")
diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/run_llama2_70b_pvc_1550_4_card.sh b/python/llm/example/GPU/Deepspeed-AutoTP/run_llama2_70b_pvc_1550_1_card.sh
similarity index 88%
rename from python/llm/example/GPU/Deepspeed-AutoTP/run_llama2_70b_pvc_1550_4_card.sh
rename to python/llm/example/GPU/Deepspeed-AutoTP/run_llama2_70b_pvc_1550_1_card.sh
index 6c91ffd5..380e1a58 100644
--- a/python/llm/example/GPU/Deepspeed-AutoTP/run_llama2_70b_pvc_1550_4_card.sh
+++ b/python/llm/example/GPU/Deepspeed-AutoTP/run_llama2_70b_pvc_1550_1_card.sh
@@ -14,12 +14,14 @@
 # limitations under the License.
 #
 
-export ZE_AFFINITY_MASK="0,1,2,3,4,5,6,7" # specify the used GPU
-NUM_GPUS=8 # number of used GPU
+export ZE_AFFINITY_MASK="0,1" # specify the GPUs to use
+NUM_GPUS=2 # number of GPUs to use
 export MASTER_ADDR=127.0.0.1
 export FI_PROVIDER=tcp
 
-export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}
+export CCL_ATL_TRANSPORT=ofi
+export CCL_ZE_IPC_EXCHANGE=sockets
+export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}
 basekit_root=/opt/intel/oneapi
 source $basekit_root/setvars.sh --force
 source $basekit_root/ccl/latest/env/vars.sh --force
diff --git a/python/llm/example/GPU/Deepspeed-AutoTP/run_vicuna_33b_arc_2_card.sh b/python/llm/example/GPU/Deepspeed-AutoTP/run_vicuna_33b_arc_2_card.sh
new file mode 100644
index 00000000..ca0697a6
--- /dev/null
+++ b/python/llm/example/GPU/Deepspeed-AutoTP/run_vicuna_33b_arc_2_card.sh
@@ -0,0 +1,33 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+export MASTER_ADDR=127.0.0.1
+export FI_PROVIDER=tcp
+export CCL_ATL_TRANSPORT=ofi
+export CCL_ZE_IPC_EXCHANGE=sockets
+
+export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}
+basekit_root=/opt/intel/oneapi
+source $basekit_root/setvars.sh --force
+source $basekit_root/ccl/latest/env/vars.sh --force
+
+NUM_GPUS=2 # number of GPUs to use
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
+export TORCH_LLM_ALLREDUCE=0 # Different from PVC
+
+mpirun -np $NUM_GPUS --prepend-rank \
+       python deepspeed_autotp.py --repo-id-or-model-path 'lmsys/vicuna-33b-v1.3'
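
---

For readers skimming this patch, the sketch below condenses the two-stage flow that `deepspeed_autotp.py` follows after this change: DeepSpeed AutoTP sharding and bigdl-llm low-bit conversion run with the CPU accelerator selected, and only afterwards is the accelerator switched to XPU and the already-compressed shards moved to GPU. It is a minimal illustration assembled from the calls visible in the hunks above, not a drop-in replacement for the full script; the model path is a placeholder, and the `deepspeed.init_inference` arguments not shown in this diff (`mp_size`, `dtype`) are assumptions.

```python
# Minimal sketch of the CPU-then-XPU flow introduced by this patch (assumptions noted inline).
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM
from bigdl.llm import optimize_model
from deepspeed.accelerator import set_accelerator, get_accelerator
from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator
from intel_extension_for_deepspeed import XPU_Accelerator

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

# Stage 1: select the CPU accelerator so AutoTP sharding and low-bit conversion stay off the GPU
set_accelerator(CPU_Accelerator())
model = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-33b-v1.3",  # placeholder model path
                                             device_map={"": "cpu"},
                                             low_cpu_mem_usage=True,
                                             torch_dtype=torch.float16)
model = deepspeed.init_inference(model,
                                 mp_size=world_size,   # assumed; only replace_method is visible in the diff
                                 dtype=torch.float16,  # assumed
                                 replace_method="auto")

# Low-bit (sym_int4) conversion on CPU, remaining weights cast to fp16 to reduce allreduce traffic
model = optimize_model(model.module.to('cpu'), low_bit='sym_int4').to(torch.float16)

# Stage 2: switch to the XPU accelerator and move the already-compressed shards to the GPU
set_accelerator(XPU_Accelerator())
model = model.to(f'xpu:{local_rank}')
if world_size > 1:
    get_accelerator().set_device(local_rank)
    import deepspeed.comm.comm
    deepspeed.comm.comm.cdb = None  # drop the CPU communication backend so it is re-created for XPU
    deepspeed.comm.comm.init_distributed()
```

As in the scripts above, one process per GPU is launched with `mpirun -np $NUM_GPUS --prepend-rank python deepspeed_autotp.py ...`, which is what provides the rank and world-size environment variables (or their PMI equivalents) that the script reads.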