LLM: update DeepSpeed AutoTP example with GPU memory optimization (#9823)

binbin Deng 2024-01-09 09:22:49 +08:00 committed by GitHub
parent 5ba1dc38d4
commit 294fd32787
4 changed files with 88 additions and 16 deletions


@@ -26,15 +26,28 @@ conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
> **Important**: IPEX 2.1.10+xpu requires Intel® oneAPI Base Toolkit's version == 2024.0. Please make sure you have installed the correct version.
### 2. Run tensor parallel inference on multiple GPUs
Here, we separate the inference process into two stages. First, the model is converted to a DeepSpeed model and bigdl-llm optimization is applied on CPU. Then, XPU is used as the DeepSpeed accelerator to run inference. In this way, an *X*B model saved in 16-bit precision requires approximately 0.5*X* GB of GPU memory in total over the whole process. For example, if you choose to use two GPUs, about 0.25*X* GB of memory is required per GPU.
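To make the rule of thumb above concrete, here is a small illustrative calculation; it is not part of the example scripts, and the helper name is ours:

```python
# Rough estimate based on the rule of thumb above: an X-billion-parameter model
# converted to sym_int4 needs about 0.5*X GB of GPU memory in total, sharded
# evenly across the tensor-parallel ranks by DeepSpeed AutoTP.
def estimated_gpu_mem_gb(num_billion_params: float, num_gpus: int) -> float:
    total_gb = 0.5 * num_billion_params
    return total_gb / num_gpus

print(estimated_gpu_mem_gb(70, 2))  # LLaMA2-70B on one Max 1550 (2 tiles): ~17.5 GB per device
print(estimated_gpu_mem_gb(33, 2))  # Vicuna-33B on two Arc A770: ~8.25 GB per device
```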
Please select an appropriate model size based on the capabilities of your machine. We provide example usages for different models and hardware as follows:
- Run LLaMA2-70B on one card of Intel Data Center GPU Max 1550
```
bash run_llama2_70b_pvc_1550_1_card.sh
```
> **Note**: You can change `ZE_AFFINITY_MASK` and `NUM_GPUS` according to your requirements.
- Run Vicuna-33B on two Intel Arc A770
```
bash run_vicuna_33b_arc_2_card.sh
```
> **Note**: You can change `NUM_GPUS` to the number of GPUs you have on your machine.
### Known Issue
- In our example scripts, tcmalloc is enabled through `export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}`, which speeds up inference, but this may raise a `munmap_chunk(): invalid pointer` error after inference finishes.


@@ -29,6 +29,10 @@ def get_int_from_env(env_keys, default):
local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
from bigdl.llm import optimize_model
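The rank/world-size bootstrap above relies on `get_int_from_env`, whose definition lies outside this hunk (only its signature is visible in the hunk header). A minimal sketch of what it presumably does, shown here as an assumption rather than the actual definition:

```python
import os

# Assumed behavior: return the first non-negative integer found among the given
# environment variables, falling back to `default` if none is set.
def get_int_from_env(env_keys, default):
    for key in env_keys:
        value = int(os.environ.get(key, -1))
        if value >= 0:
            return value
    return int(default)
```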
@@ -38,6 +42,9 @@ import argparse
from transformers import AutoModelForCausalLM # export AutoModelForCausalLM from transformers so that deepspeed uses it
from transformers import LlamaTokenizer, AutoTokenizer
from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator
from deepspeed.accelerator import set_accelerator, get_accelerator
from intel_extension_for_deepspeed import XPU_Accelerator
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
@@ -52,7 +59,12 @@ if __name__ == '__main__':
args = parser.parse_args()
model_path = args.repo_id_or_model_path
# First use CPU as accelerator
# Convert to deepspeed model and apply bigdl-llm optimization on CPU to decrease GPU memory usage
current_accel = CPU_Accelerator()
set_accelerator(current_accel)
model = AutoModelForCausalLM.from_pretrained(args.repo_id_or_model_path,
device_map={"": "cpu"},
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
trust_remote_code=True,
@@ -65,14 +77,26 @@ if __name__ == '__main__':
replace_method="auto",
)
# Use bigdl-llm `optimize_model` to convert the model into optimized low bit format
# Convert the rest of the model into float16 to reduce allreduce traffic
model = optimize_model(model.module.to(f'cpu'), low_bit='sym_int4').to(torch.float16)
# Next, use XPU as accelerator to speed up inference
current_accel = XPU_Accelerator()
set_accelerator(current_accel)
# Move model back to xpu
model = model.to(f'xpu:{local_rank}')
# Modify backend related settings
if world_size > 1:
get_accelerator().set_device(local_rank)
dist_backend = get_accelerator().communication_backend_name()
import deepspeed.comm.comm
deepspeed.comm.comm.cdb = None
from deepspeed.comm.comm import init_distributed
init_distributed()
print(model)
# Load tokenizer
@@ -80,9 +104,7 @@ if __name__ == '__main__':
# Generate predicted tokens
with torch.inference_mode():
# prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
prompt = args.prompt
# input_str = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')
# ipex model needs a warmup, then inference time can be accurate
output = model.generate(input_ids,
@@ -108,3 +130,5 @@ if __name__ == '__main__':
print(prompt)
print('-'*20, 'Output', '-'*20)
print(output_str)
deepspeed.comm.destroy_process_group()
print("process group destroyed, exiting...")
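Taken together, the hunks above implement the two-stage flow described in the README: build and quantize the tensor-parallel model with the CPU accelerator, then switch to the XPU accelerator for inference. A condensed sketch of that flow follows; the model path, the `init_inference` keyword arguments not visible in the hunks, and the omission of tokenization and generation are assumptions made here for brevity:

```python
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM
from deepspeed.accelerator import set_accelerator, get_accelerator
from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator
from intel_extension_for_deepspeed import XPU_Accelerator
from bigdl.llm import optimize_model

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

# Stage 1: shard and quantize on CPU so the unquantized float16 weights
# never occupy GPU memory.
set_accelerator(CPU_Accelerator())
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-70b-hf",  # placeholder model path
                                             device_map={"": "cpu"},
                                             low_cpu_mem_usage=True,
                                             torch_dtype=torch.float16,
                                             trust_remote_code=True)
model = deepspeed.init_inference(model,
                                 mp_size=world_size,       # assumed; only replace_method
                                 replace_method="auto")    # is visible in the hunks above
model = optimize_model(model.module.to('cpu'), low_bit='sym_int4').to(torch.float16)

# Stage 2: switch DeepSpeed's accelerator to XPU and move the quantized shard over.
set_accelerator(XPU_Accelerator())
model = model.to(f'xpu:{local_rank}')
if world_size > 1:
    get_accelerator().set_device(local_rank)
    import deepspeed.comm.comm
    deepspeed.comm.comm.cdb = None            # drop the communicator created on the CPU backend
    from deepspeed.comm.comm import init_distributed
    init_distributed()                        # re-initialize on the XPU backend
```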


@@ -14,12 +14,14 @@
# limitations under the License.
#
export ZE_AFFINITY_MASK="0,1" # specify the GPUs to use
NUM_GPUS=2 # number of GPUs used
export MASTER_ADDR=127.0.0.1
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export CCL_ZE_IPC_EXCHANGE=sockets
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}
basekit_root=/opt/intel/oneapi
source $basekit_root/setvars.sh --force
source $basekit_root/ccl/latest/env/vars.sh --force
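For reference on how `ZE_AFFINITY_MASK` and `NUM_GPUS` scale with the number of Max 1550 cards used by this script (each card is exposed as two tile devices, which is why one card corresponds to `0,1` and `2`), here is a small illustrative helper; the function name and the tile-per-device assumption are ours, not part of the example:

```python
# Illustrative only: derive ZE_AFFINITY_MASK / NUM_GPUS for a given number of
# Max 1550 cards, assuming each card shows up as two separate tile devices.
def pvc_1550_env(num_cards: int) -> dict:
    num_gpus = num_cards * 2
    return {
        "ZE_AFFINITY_MASK": ",".join(str(i) for i in range(num_gpus)),
        "NUM_GPUS": str(num_gpus),
    }

print(pvc_1550_env(1))  # {'ZE_AFFINITY_MASK': '0,1', 'NUM_GPUS': '2'} -- matches the script above
```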


@@ -0,0 +1,33 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
export MASTER_ADDR=127.0.0.1
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export CCL_ZE_IPC_EXCHANGE=sockets
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}
basekit_root=/opt/intel/oneapi
source $basekit_root/setvars.sh --force
source $basekit_root/ccl/latest/env/vars.sh --force
NUM_GPUS=2 # number of GPUs used
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export TORCH_LLM_ALLREDUCE=0 # Different from PVC
mpirun -np $NUM_GPUS --prepend-rank \
python deepspeed_autotp.py --repo-id-or-model-path 'lmsys/vicuna-33b-v1.3'