Add deepspeed-autoTP-Fastapi serving (#10748)

* add deepspeed-autoTP-Fastapi serving

* add readme

* add license

* update

* update

* fix
ZehuaCao 2024-04-16 14:03:23 +08:00 committed by GitHub
parent a7c12020b4
commit 599a88db53
5 changed files with 274 additions and 2 deletions


@@ -0,0 +1,84 @@
# Run IPEX-LLM serving on Multiple Intel GPUs using DeepSpeed AutoTP and FastApi
This example demonstrates how to run IPEX-LLM serving on multiple [Intel GPUs](../README.md) by leveraging DeepSpeed AutoTP and exposing the model through a FastAPI endpoint.
## Requirements
To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine.
## Example
### 1. Install
```bash
conda create -n llm python=3.11
conda activate llm
# the command below installs intel_extension_for_pytorch==2.1.10+xpu by default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
# configure oneAPI environment variables
source /opt/intel/oneapi/setvars.sh
pip install git+https://github.com/microsoft/DeepSpeed.git@ed8aed5
pip install git+https://github.com/intel/intel-extension-for-deepspeed.git@0eb734b
pip install mpi4py fastapi uvicorn
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
```
> **Important**: IPEX 2.1.10+xpu requires Intel® oneAPI Base Toolkit's version == 2024.0. Please make sure you have installed the correct version.
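Before moving on, you may want to confirm that the environment actually sees your GPUs. Below is a minimal sanity-check sketch (it assumes the `llm` conda environment is active and the oneAPI variables above have been sourced):
```python
# Minimal sketch: verify that the XPU backend is available and both GPUs are visible.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401 -- importing IPEX registers the XPU backend

print(f"XPU available: {torch.xpu.is_available()}")
print(f"GPU count: {torch.xpu.device_count()}")
for i in range(torch.xpu.device_count()):
    print(f"  [{i}] {torch.xpu.get_device_name(i)}")
```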
### 2. Run tensor parallel inference on multiple GPUs
When we run the model distributed across two GPUs, each GPU holds only half of the weights, so per-GPU memory consumption is roughly halved, and the two GPUs work simultaneously during inference. (You can check the per-GPU memory with the helper sketched after the note below.)
We provide example usage for the `Llama-2-7b-chat-hf` model running on Intel Arc A770.
Run Llama-2-7b-chat-hf on two Intel Arc A770 cards:
```bash
# Before running this script, adjust YOUR_REPO_ID_OR_MODEL_PATH in the last line
# To change the server port, set the --port parameter in the last line
bash run_llama2_7b_chat_hf_arc_2_card.sh
```
If the serving starts successfully, you should see output like this:
```bash
[0] INFO: Started server process [120071]
[0] INFO: Waiting for application startup.
[0] INFO: Application startup complete.
[0] INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
```
> **Note**: You can change `NUM_GPUS` to the number of GPUs you have on your machine, and you can specify other low-bit optimizations through `--low-bit`.
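If you want to confirm the memory split described above, a hypothetical helper like the one below could be called at the end of `load_model` in `serving.py`. The helper name is ours, and it assumes your IPEX build exposes `torch.xpu.memory_allocated` (mirroring the `torch.cuda` memory APIs):
```python
# Hypothetical debugging helper, not part of the example scripts.
import torch

def log_xpu_memory(rank: int):
    """Print how much XPU memory this rank has allocated after loading its model shard."""
    mem_gb = torch.xpu.memory_allocated(rank) / (1024 ** 3)
    print(f"[rank {rank}] memory allocated on xpu:{rank}: {mem_gb:.2f} GB")
```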
### 3. Sample Input and Output
We can use `curl` to test the serving API:
```bash
# Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy.
export http_proxy=
export https_proxy=
curl -X 'POST' \
'http://127.0.0.1:8000/generate/' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"prompt": "What is AI?",
"n_predict": 32
}'
```
And you should get output like this:
```json
{
"generated_text": "What is AI? Artificial intelligence (AI) refers to the development of computer systems able to perform tasks that would normally require human intelligence, such as visual perception, speech",
"generate_time": "0.45149803161621094s"
}
```
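You can also query the endpoint from Python instead of `curl`. The sketch below assumes the server is reachable at `127.0.0.1:8000` and that the `requests` package is installed (it is not part of the install step above):
```python
# Minimal Python client for the /generate/ endpoint.
import requests

response = requests.post(
    "http://127.0.0.1:8000/generate/",
    json={"prompt": "What is AI?", "n_predict": 32},
    timeout=300,  # the first request is much slower due to warm-up
)
response.raise_for_status()
result = response.json()
print(result["generated_text"])
print(result["generate_time"])
```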
**Important**: The first token latency is much larger than the rest token latency; you can use [our benchmark tool](https://github.com/intel-analytics/ipex-llm/blob/main/python/llm/dev/benchmark/README.md) to obtain more details about first and rest token latency.


@@ -0,0 +1,35 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
export MASTER_ADDR=127.0.0.1
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export CCL_ZE_IPC_EXCHANGE=sockets
export no_proxy=localhost
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}
basekit_root=/opt/intel/oneapi
source $basekit_root/setvars.sh --force
source $basekit_root/ccl/latest/env/vars.sh --force
NUM_GPUS=2 # number of GPUs to use
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export TORCH_LLM_ALLREDUCE=0
mpirun -np $NUM_GPUS --prepend-rank \
python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000


@@ -0,0 +1,149 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import torch
import transformers
import time
import argparse
import torch.distributed as dist
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn

def get_int_from_env(env_keys, default):
    """Returns the first positive env value found in the `env_keys` list or the default."""
    for e in env_keys:
        val = int(os.environ.get(e, -1))
        if val >= 0:
            return val
    return int(default)

# Rank and world size are provided by mpirun (PMI); fall back to single-process defaults.
local_rank = get_int_from_env(["LOCAL_RANK", "PMI_RANK"], "0")
world_size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE"], "1")
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")

# `model` and `tokenizer` are populated by load_model() and shared by the request handlers.
global model, tokenizer

def load_model(model_path, low_bit):
    from ipex_llm import optimize_model
    import torch
    import time
    import argparse
    from transformers import AutoModelForCausalLM  # import AutoModelForCausalLM from transformers so that DeepSpeed can use it
    from transformers import LlamaTokenizer, AutoTokenizer
    import deepspeed
    from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator
    from deepspeed.accelerator import set_accelerator, get_accelerator
    from intel_extension_for_deepspeed import XPU_Accelerator

    # First use CPU as the accelerator.
    # Convert to a DeepSpeed model and apply the IPEX-LLM optimization on CPU to decrease GPU memory usage.
    current_accel = CPU_Accelerator()
    set_accelerator(current_accel)
    global model, tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 device_map={"": "cpu"},
                                                 low_cpu_mem_usage=True,
                                                 torch_dtype=torch.float16,
                                                 trust_remote_code=True,
                                                 use_cache=True)

    model = deepspeed.init_inference(
        model,
        mp_size=world_size,
        dtype=torch.bfloat16,
        replace_method="auto",
    )

    # Use IPEX-LLM `optimize_model` to convert the model into the optimized low-bit format.
    # Convert the rest of the model into float16 to reduce allreduce traffic.
    model = optimize_model(model.module.to('cpu'), low_bit=low_bit).to(torch.float16)

    # Next, use XPU as the accelerator to speed up inference.
    current_accel = XPU_Accelerator()
    set_accelerator(current_accel)

    # Move the model back to XPU.
    model = model.to(f'xpu:{local_rank}')

    # Modify backend-related settings.
    if world_size > 1:
        get_accelerator().set_device(local_rank)
        dist_backend = get_accelerator().communication_backend_name()
        import deepspeed.comm.comm
        deepspeed.comm.comm.cdb = None
        from deepspeed.comm.comm import init_distributed
        init_distributed()

    # Load the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

def generate_text(prompt: str, n_predict: int = 32):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')
    output = model.generate(input_ids,
                            max_new_tokens=n_predict,
                            use_cache=True)
    torch.xpu.synchronize()
    return output

class PromptRequest(BaseModel):
    prompt: str
    n_predict: int = 32

app = FastAPI()

@app.post("/generate/")
async def generate(prompt_request: PromptRequest):
    # Only rank 0 runs the HTTP server; it broadcasts each prompt so that
    # every tensor-parallel rank participates in the same generate call.
    if local_rank == 0:
        object_list = [prompt_request]
        dist.broadcast_object_list(object_list, src=0)
        start_time = time.time()
        output = generate_text(object_list[0].prompt, object_list[0].n_predict)
        generate_time = time.time() - start_time
        output = output.cpu()
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
        return {"generated_text": output_str, "generate_time": f'{generate_time:.3f}s'}

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--low-bit', type=str, default='sym_int4',
                        help='The quantization type the model will convert to.')
    parser.add_argument('--port', type=int, default=8000,
                        help='The port number on which the server will run.')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    low_bit = args.low_bit
    load_model(model_path, low_bit)

    if local_rank == 0:
        uvicorn.run(app, host="0.0.0.0", port=args.port)
    else:
        # Non-zero ranks never serve HTTP; they block on the broadcast from rank 0
        # and join each collective generate call with their model shard.
        while True:
            object_list = [None]
            dist.broadcast_object_list(object_list, src=0)
            output = generate_text(object_list[0].prompt, object_list[0].n_predict)


@@ -3,6 +3,7 @@
This example demonstrates how to run IPEX-LLM optimized low-bit model on multiple [Intel GPUs](../README.md) by leveraging DeepSpeed AutoTP.
## Requirements
To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine.
## Example:
@@ -25,6 +26,7 @@ conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
> **Important**: IPEX 2.1.10+xpu requires Intel® oneAPI Base Toolkit's version == 2024.0. Please make sure you have installed the correct version.
### 2. Run tensor parallel inference on multiple GPUs
Here, we separate inference process into two stages. First, convert to deepspeed model and apply ipex-llm optimization on CPU. Then, utilize XPU as DeepSpeed accelerator to inference. In this way, a *X*B model saved in 16-bit will requires approximately 0.5*X* GB total GPU memory in the whole process. For example, if you select to use two GPUs, 0.25*X* GB memory is required per GPU.
Please select the appropriate model size based on the capabilities of your machine.
@@ -33,7 +35,7 @@ We provide example usages on different models and different hardwares as following
- Run LLaMA2-70B on one card of Intel Data Center GPU Max 1550
```bash
bash run_llama2_70b_pvc_1550_1_card.sh
```
@@ -41,7 +43,7 @@ bash run_llama2_70b_pvc_1550_1_card.sh
- Run Vicuna-33B on two Intel Arc A770
```bash
bash run_vicuna_33b_arc_2_card.sh
```
@@ -62,4 +64,5 @@ bash run_vicuna_33b_arc_2_card.sh
**Important**: The first token latency is much larger than rest token latency, you could use [our benchmark tool](https://github.com/intel-analytics/ipex-llm/blob/main/python/llm/dev/benchmark/README.md) to obtain more details about first and rest token latency.
### Known Issue
- In our example scripts, tcmalloc is enabled through `export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so:${LD_PRELOAD}` which speed up inference, but this may raise `munmap_chunk(): invalid pointer` error after finishing inference.


@@ -7,6 +7,7 @@ This folder contains examples of running IPEX-LLM on Intel GPU:
- [LLM-Finetuning](LLM-Finetuning): running ***finetuning*** (such as LoRA, QLoRA, QA-LoRA, etc) using IPEX-LLM on Intel GPUs
- [vLLM-Serving](vLLM-Serving): running ***vLLM*** serving framework on intel GPUs (with IPEX-LLM low-bit optimized models)
- [Deepspeed-AutoTP](Deepspeed-AutoTP): running distributed inference using ***DeepSpeed AutoTP*** (with IPEX-LLM low-bit optimized models) on Intel GPUs
- [Deepspeed-AutoTP-FastApi](Deepspeed-AutoTP-FastApi): running distributed inference using ***DeepSpeed AutoTP*** and serving it with ***FastApi*** (with IPEX-LLM low-bit optimized models) on Intel GPUs
- [LangChain](LangChain): running ***LangChain*** applications on IPEX-LLM
- [PyTorch-Models](PyTorch-Models): running any PyTorch model on IPEX-LLM (with "one-line code change")
- [Speculative-Decoding](Speculative-Decoding): running any ***Hugging Face Transformers*** model with ***self-speculative decoding*** on Intel GPUs