LLM: update DeepSpeed AutoTP example with GPU memory optimization (#9823)
parent 5ba1dc38d4
commit 294fd32787
4 changed files with 88 additions and 16 deletions
@@ -26,15 +26,28 @@ conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 > **Important**: IPEX 2.1.10+xpu requires Intel® oneAPI Base Toolkit's version == 2024.0. Please make sure you have installed the correct version.
 
 ### 2. Run tensor parallel inference on multiple GPUs
-Here, we provide example usages on different models and different hardwares. Please refer to the appropriate script based on your model and device:
-
-#### Llama2 series
-
-<details><summary>Show LLaMA2-70B example</summary>
-
-Run LLaMA2-70B on four Intel Data Center GPU Max 1550
+Here, we separate the inference process into two stages. First, convert to a DeepSpeed model and apply bigdl-llm optimization on CPU. Then, use XPU as the DeepSpeed accelerator to run inference. In this way, an *X*B model saved in 16-bit requires approximately 0.5*X* GB of GPU memory in total over the whole process. For example, if you use two GPUs, 0.25*X* GB of memory is required per GPU.
+
+Please select the appropriate model size based on the capabilities of your machine.
+
+We provide example usages on different models and different hardware as follows:
+
+- Run LLaMA2-70B on one card of Intel Data Center GPU Max 1550
 
 ```
-bash run_llama2_70b_pvc_1550_4_card.sh
+bash run_llama2_70b_pvc_1550_1_card.sh
 ```
-</details>
 
-> **Note**:If you may want to select only part of GPUs on your machine, please change `ZE_AFFINITY_MASK` and `NUM_GPUS` to your prefer value.
+> **Note**: You could change `ZE_AFFINITY_MASK` and `NUM_GPUS` according to your requirements.
+
+- Run Vicuna-33B on two Intel Arc A770
+
+```
+bash run_vicuna_33b_arc_2_card.sh
+```
+
+> **Note**: You could change `NUM_GPUS` to the number of GPUs you have on your machine.
+
+### Known Issue
+- In our example scripts, tcmalloc is enabled through `export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}`, which speeds up inference, but this may raise a `munmap_chunk(): invalid pointer` error after inference finishes.
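The memory figure quoted in the new README text is easy to sanity-check. The sketch below is not part of the commit; it assumes sym_int4 weights take roughly half a byte per parameter and ignores activations, KV cache, and runtime overhead:

```
# Rough estimate behind the "0.5*X GB" claim for an X-billion-parameter model.
def estimate_autotp_memory_gb(num_params_billion: float, num_gpus: int) -> float:
    bytes_per_param = 0.5                              # sym_int4: ~4 bits per weight
    total_gb = num_params_billion * bytes_per_param    # ~0.5*X GB in total
    return total_gb / num_gpus                         # per-GPU share under tensor parallelism

# Example: a 70B model on two GPUs -> 35.0 GB in total, 17.5 GB per GPU.
print(estimate_autotp_memory_gb(70, 2))
```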
@@ -29,6 +29,10 @@ def get_int_from_env(env_keys, default):
 local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
 world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
+os.environ["RANK"] = str(local_rank)
+os.environ["WORLD_SIZE"] = str(world_size)
+os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
 
 from bigdl.llm import optimize_model
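The added `os.environ` lines mirror the launcher-provided rank and world size into the variables that torch.distributed and DeepSpeed read during rendezvous. Below is a minimal standalone sketch of the same bootstrapping; `int_from_env` is a hypothetical stand-in for the script's `get_int_from_env`, whose body is not shown in this hunk:

```
import os

def int_from_env(keys, default):
    # Hypothetical stand-in: return the first listed environment variable that is set.
    for key in keys:
        value = os.environ.get(key)
        if value:
            return int(value)
    return int(default)

local_rank = int_from_env(["LOCAL_RANK", "PMI_RANK"], "0")
world_size = int_from_env(["WORLD_SIZE", "PMI_SIZE"], "1")

# torch.distributed / DeepSpeed expect RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT.
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")   # the run scripts already export this
os.environ.setdefault("MASTER_PORT", "29500")
print(f"rank {local_rank} of {world_size}")
```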
@@ -38,6 +42,9 @@ import argparse
 from transformers import AutoModelForCausalLM  # export AutoModelForCausalLM from transformers so that deepspeed use it
 from transformers import LlamaTokenizer, AutoTokenizer
+from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator
+from deepspeed.accelerator import set_accelerator, get_accelerator
+from intel_extension_for_deepspeed import XPU_Accelerator
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
@@ -52,7 +59,12 @@ if __name__ == '__main__':
     args = parser.parse_args()
     model_path = args.repo_id_or_model_path
 
+    # First use CPU as accelerator
+    # Convert to deepspeed model and apply bigdl-llm optimization on CPU to decrease GPU memory usage
+    current_accel = CPU_Accelerator()
+    set_accelerator(current_accel)
     model = AutoModelForCausalLM.from_pretrained(args.repo_id_or_model_path,
+                                                 device_map={"": "cpu"},
                                                  low_cpu_mem_usage=True,
                                                  torch_dtype=torch.float16,
                                                  trust_remote_code=True,
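The new `device_map={"": "cpu"}` argument pins the whole checkpoint to host memory while it is loaded, so stage one never touches the GPU; the empty-string key addresses the model's root module in the Accelerate device-map convention. A tiny illustration, using a small checkpoint chosen only for the example:

```
import torch
from transformers import AutoModelForCausalLM

# "" means the root module, i.e. every submodule is placed on CPU while loading;
# low_cpu_mem_usage avoids materializing a second full copy of the weights.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m",   # illustrative small model
                                             device_map={"": "cpu"},
                                             low_cpu_mem_usage=True,
                                             torch_dtype=torch.float16)
print(next(model.parameters()).device)   # -> cpu
```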
@@ -65,14 +77,26 @@ if __name__ == '__main__':
                                       replace_method="auto",
                                       )
 
-    # move model to cpu and use bigdl-llm `optimize_model` to convert the
-    # model into optimized low bit format
-    # convert the rest of the model into float16 to reduce allreduce traffic
+    # Use bigdl-llm `optimize_model` to convert the model into optimized low bit format
+    # Convert the rest of the model into float16 to reduce allreduce traffic
     model = optimize_model(model.module.to(f'cpu'), low_bit='sym_int4').to(torch.float16)
 
-    # move model back to xpu
+    # Next, use XPU as accelerator to speed up inference
+    current_accel = XPU_Accelerator()
+    set_accelerator(current_accel)
+
+    # Move model back to xpu
     model = model.to(f'xpu:{local_rank}')
 
+    # Modify backend related settings
+    if world_size > 1:
+        get_accelerator().set_device(local_rank)
+        dist_backend = get_accelerator().communication_backend_name()
+        import deepspeed.comm.comm
+        deepspeed.comm.comm.cdb = None
+        from deepspeed.comm.comm import init_distributed
+        init_distributed()
 
     print(model)
 
     # Load tokenizer
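Taken together, these hunks implement the two-stage flow described in the README: shard and quantize on CPU first, then hand the already low-bit shard to the XPU. A condensed sketch of that sequence follows; the model id and the `init_inference` arguments are illustrative, and the committed `deepspeed_autotp.py` remains the reference:

```
import os

import deepspeed
import torch
from transformers import AutoModelForCausalLM
from bigdl.llm import optimize_model
from deepspeed.accelerator import set_accelerator, get_accelerator
from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator
from intel_extension_for_deepspeed import XPU_Accelerator

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

# Stage 1: keep everything on CPU while DeepSpeed shards the checkpoint,
# so the full 16-bit weights never occupy GPU memory.
set_accelerator(CPU_Accelerator())
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-70b-hf",  # illustrative model id
                                             device_map={"": "cpu"},
                                             low_cpu_mem_usage=True,
                                             torch_dtype=torch.float16,
                                             trust_remote_code=True)
model = deepspeed.init_inference(model,
                                 mp_size=world_size,   # argument names vary across DeepSpeed versions
                                 dtype=torch.float16,
                                 replace_method="auto")

# Quantize the CPU-resident shard with bigdl-llm, then cast the remaining
# modules to float16 to keep allreduce traffic small.
model = optimize_model(model.module.to('cpu'), low_bit='sym_int4').to(torch.float16)

# Stage 2: switch the active accelerator to XPU and move only the already
# quantized shard onto the local device.
set_accelerator(XPU_Accelerator())
model = model.to(f'xpu:{local_rank}')

if world_size > 1:
    get_accelerator().set_device(local_rank)
    import deepspeed.comm.comm
    deepspeed.comm.comm.cdb = None    # drop the communication backend that was built for CPU
    from deepspeed.comm.comm import init_distributed
    init_distributed()                # re-initialize against the XPU backend
```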
@@ -80,9 +104,7 @@ if __name__ == '__main__':
 
     # Generate predicted tokens
     with torch.inference_mode():
-        # prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
         prompt = args.prompt
-        # input_str = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n"
         input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')
         # ipex model needs a warmup, then inference time can be accurate
         output = model.generate(input_ids,
@@ -108,3 +130,5 @@ if __name__ == '__main__':
         print(prompt)
         print('-'*20, 'Output', '-'*20)
         print(output_str)
+    deepspeed.comm.destroy_process_group()
+    print("process group destroyed, exiting...")
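The new cleanup lines call `deepspeed.comm.destroy_process_group()` unconditionally. If you adapt the script to also support single-process runs, a guarded variant avoids tearing down a group that was never created; this is a sketch, not part of the commit, and assumes `deepspeed.comm` exposes `is_initialized()` as recent releases do:

```
import deepspeed.comm as dist

def shutdown_process_group():
    # Only destroy the communication backend if one was actually initialized;
    # deepspeed.comm mirrors torch.distributed's is_initialized() check.
    if dist.is_initialized():
        dist.destroy_process_group()
        print("process group destroyed, exiting...")
```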
@@ -14,12 +14,14 @@
 # limitations under the License.
 #
 
-export ZE_AFFINITY_MASK="0,1,2,3,4,5,6,7" # specify the used GPU
-NUM_GPUS=8 # number of used GPU
+export ZE_AFFINITY_MASK="0,1" # specify the used GPU
+NUM_GPUS=2 # number of used GPU
 export MASTER_ADDR=127.0.0.1
 export FI_PROVIDER=tcp
-export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}
+export CCL_ATL_TRANSPORT=ofi
+export CCL_ZE_IPC_EXCHANGE=sockets
+
+export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}
 basekit_root=/opt/intel/oneapi
 source $basekit_root/setvars.sh --force
 source $basekit_root/ccl/latest/env/vars.sh --force
@@ -0,0 +1,33 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+export MASTER_ADDR=127.0.0.1
+export FI_PROVIDER=tcp
+export CCL_ATL_TRANSPORT=ofi
+export CCL_ZE_IPC_EXCHANGE=sockets
+
+export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}
+basekit_root=/opt/intel/oneapi
+source $basekit_root/setvars.sh --force
+source $basekit_root/ccl/latest/env/vars.sh --force
+
+NUM_GPUS=2 # number of used GPU
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
+export TORCH_LLM_ALLREDUCE=0 # Different from PVC
+
+mpirun -np $NUM_GPUS --prepend-rank \
+       python deepspeed_autotp.py --repo-id-or-model-path 'lmsys/vicuna-33b-v1.3'