Upgrade to vLLM 0.6.6 (#12796)
* init
* update engine init
* fix serving load_in_low_bit problem
* temp
* temp
* temp
* temp
* temp
* fix
* fixed
* done
* fix
* fix all arguments
* fix
* fix throughput script
* fix
* fix
* use official ipex-llm
* Fix readme
* fix

---------

Co-authored-by: hzjane <a1015616934@qq.com>
This commit is contained in:
parent
f8ab833f74
commit
af693425f1
14 changed files with 1003 additions and 910 deletions
|
|
@ -1,27 +1,17 @@
|
|||
FROM intel/oneapi-basekit:2024.1.1-devel-ubuntu22.04
|
||||
# First stage: build oneccl
|
||||
FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 AS build
|
||||
|
||||
ARG http_proxy
|
||||
ARG https_proxy
|
||||
|
||||
ENV TZ=Asia/Shanghai
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
# To prevent RPC_TIMEOUT ERROR for the first request
|
||||
ENV VLLM_RPC_TIMEOUT=100000
|
||||
|
||||
|
||||
# Disable pip's cache behavior
|
||||
ARG PIP_NO_CACHE_DIR=false
|
||||
ADD ./gradio_web_server.patch /tmp/gradio_web_server.patch
|
||||
ADD ./oneccl-binding.patch /tmp/oneccl-binding.patch
|
||||
|
||||
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
|
||||
echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
|
||||
chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \
|
||||
rm /etc/apt/sources.list.d/intel-graphics.list && \
|
||||
wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \
|
||||
echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \
|
||||
chmod 644 /usr/share/keyrings/intel-graphics.gpg && \
|
||||
apt-get update && \
|
||||
ADD ./ccl_torch.patch /tmp/
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends curl wget git libunwind8-dev vim less && \
|
||||
ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && \
|
||||
env DEBIAN_FRONTEND=noninteractive apt-get update && \
|
||||
|
|
@ -39,65 +29,105 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
|
|||
python3 get-pip.py && \
|
||||
rm get-pip.py && \
|
||||
pip install --upgrade requests argparse urllib3 && \
|
||||
pip install --pre --upgrade ipex-llm[xpu,serving] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
|
||||
apt-get install -y --no-install-recommends libfabric-dev wrk libaio-dev numactl && \
|
||||
# If we do not install this compute-runtime, the build will fail later
|
||||
mkdir -p /tmp/neo && \
|
||||
cd /tmp/neo && \
|
||||
wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.5.6/intel-igc-core-2_2.5.6+18417_amd64.deb && \
|
||||
wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.5.6/intel-igc-opencl-2_2.5.6+18417_amd64.deb && \
|
||||
wget https://github.com/intel/compute-runtime/releases/download/24.52.32224.5/intel-level-zero-gpu-dbgsym_1.6.32224.5_amd64.ddeb && \
|
||||
wget https://github.com/intel/compute-runtime/releases/download/24.52.32224.5/intel-level-zero-gpu_1.6.32224.5_amd64.deb && \
|
||||
wget https://github.com/intel/compute-runtime/releases/download/24.52.32224.5/intel-opencl-icd-dbgsym_24.52.32224.5_amd64.ddeb && \
|
||||
wget https://github.com/intel/compute-runtime/releases/download/24.52.32224.5/intel-opencl-icd_24.52.32224.5_amd64.deb && \
|
||||
wget https://github.com/intel/compute-runtime/releases/download/24.52.32224.5/libigdgmm12_22.5.5_amd64.deb && \
|
||||
dpkg -i *.deb && \
|
||||
pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/test/xpu && \
|
||||
mkdir /build && \
|
||||
cd /build && \
|
||||
git clone https://github.com/intel/torch-ccl.git && \
|
||||
cd torch-ccl && \
|
||||
git checkout ccl_torch2.5.0+xpu && \
|
||||
git submodule sync && \
|
||||
git submodule update --init --recursive && \
|
||||
# This patch enables building torch-ccl in a PyTorch 2.6 environment
|
||||
git apply /tmp/ccl_torch.patch && \
|
||||
USE_SYSTEM_ONECCL=ON COMPUTE_BACKEND=dpcpp python setup.py bdist_wheel
|
||||
# File path: /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl
|
||||
|
||||
FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04
|
||||
|
||||
COPY --from=build /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl /opt/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl
|
||||
|
||||
ARG http_proxy
|
||||
ARG https_proxy
|
||||
|
||||
ENV TZ=Asia/Shanghai
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
# To prevent RPC_TIMEOUT ERROR for the first request
|
||||
ENV VLLM_RPC_TIMEOUT=100000
|
||||
|
||||
|
||||
# Disable pip's cache behavior
|
||||
ARG PIP_NO_CACHE_DIR=false
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends curl wget git libunwind8-dev vim less && \
|
||||
ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && \
|
||||
env DEBIAN_FRONTEND=noninteractive apt-get update && \
|
||||
apt-get install -y --no-install-recommends gnupg gpg-agent software-properties-common kmod && \
|
||||
# Add Python 3.11 PPA repository
|
||||
add-apt-repository ppa:deadsnakes/ppa -y && \
|
||||
apt-get install -y --no-install-recommends python3.11 git curl wget && \
|
||||
rm /usr/bin/python3 && \
|
||||
ln -s /usr/bin/python3.11 /usr/bin/python3 && \
|
||||
ln -s /usr/bin/python3 /usr/bin/python && \
|
||||
apt-get install -y --no-install-recommends python3-pip python3.11-dev python3-wheel python3.11-distutils && \
|
||||
wget https://bootstrap.pypa.io/get-pip.py -O get-pip.py && \
|
||||
python3 get-pip.py && \
|
||||
rm get-pip.py && \
|
||||
pip install --upgrade requests argparse urllib3 && \
|
||||
pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/test/xpu && \
|
||||
pip install transformers_stream_generator einops tiktoken && \
|
||||
pip install --upgrade colorama && \
|
||||
# Download all-in-one benchmark and examples
|
||||
git clone https://github.com/intel-analytics/ipex-llm && \
|
||||
# The following comment segment is used when building from source...
|
||||
# cd ipex-llm && \
|
||||
# git fetch origin pull/12338/head:local_pr && \
|
||||
# git checkout local_pr && \
|
||||
# pip uninstall -y ipex-llm && \
|
||||
# cd python/llm && \
|
||||
# python setup.py install && \
|
||||
# cd ../../../ && \
|
||||
git clone https://github.com/intel/ipex-llm.git && \
|
||||
cp -r ./ipex-llm/python/llm/dev/benchmark/ ./benchmark && \
|
||||
cp -r ./ipex-llm/python/llm/example/GPU/HuggingFace/LLM ./examples && \
|
||||
cp -r ./ipex-llm/python/llm/example/GPU/vLLM-Serving/ ./vLLM-Serving && \
|
||||
rm -rf ./ipex-llm && \
|
||||
# Install vllm dependencies
|
||||
pip install --upgrade fastapi && \
|
||||
pip install --upgrade "uvicorn[standard]" && \
|
||||
# Download vLLM-Serving
|
||||
cp -r ./ipex-llm/python/llm/example/GPU/vLLM-Serving/ ./vLLM-Serving && \
|
||||
rm -rf ./ipex-llm && \
|
||||
# Install torch-ccl
|
||||
cd /tmp/ && \
|
||||
pip install torch==2.1.0.post2 torchvision==0.16.0.post2 torchaudio==2.1.0.post2 intel-extension-for-pytorch==2.1.30.post0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
|
||||
# Internal oneccl
|
||||
wget https://sourceforge.net/projects/oneccl-wks/files/2024.0.0.6.5-release/oneccl_wks_installer_2024.0.0.6.5.sh && \
|
||||
bash oneccl_wks_installer_2024.0.0.6.5.sh && \
|
||||
git clone https://github.com/intel/torch-ccl -b v2.1.300+xpu && \
|
||||
cd torch-ccl && \
|
||||
patch -p1 < /tmp/oneccl-binding.patch && \
|
||||
USE_SYSTEM_ONECCL=ON COMPUTE_BACKEND=dpcpp python setup.py install && \
|
||||
pip install /opt/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl && \
|
||||
# Install internal oneCCL
|
||||
cd /opt && \
|
||||
wget https://sourceforge.net/projects/oneccl-wks/files/2025.0.0.6.6-release/oneccl_wks_installer_2025.0.0.6.6.sh && \
|
||||
bash oneccl_wks_installer_2025.0.0.6.6.sh && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends libfabric-dev wrk libaio-dev numactl && \
|
||||
# apt-get install -y intel-opencl-icd intel-level-zero-gpu=1.3.26241.33-647~22.04 level-zero level-zero-dev --allow-downgrades && \
|
||||
# Install compute runtime
|
||||
mkdir -p /tmp/neo && \
|
||||
cd /tmp/neo && \
|
||||
wget https://github.com/oneapi-src/level-zero/releases/download/v1.18.5/level-zero_1.18.5+u22.04_amd64.deb && \
|
||||
wget https://github.com/intel/intel-graphics-compiler/releases/download/igc-1.0.17791.9/intel-igc-core_1.0.17791.9_amd64.deb && \
|
||||
wget https://github.com/intel/intel-graphics-compiler/releases/download/igc-1.0.17791.9/intel-igc-opencl_1.0.17791.9_amd64.deb && \
|
||||
wget https://github.com/intel/compute-runtime/releases/download/24.39.31294.12/intel-level-zero-gpu_1.6.31294.12_amd64.deb && \
|
||||
wget https://github.com/intel/compute-runtime/releases/download/24.39.31294.12/intel-opencl-icd_24.39.31294.12_amd64.deb && \
|
||||
wget https://github.com/intel/compute-runtime/releases/download/24.39.31294.12/libigdgmm12_22.5.2_amd64.deb && \
|
||||
wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.5.6/intel-igc-core-2_2.5.6+18417_amd64.deb && \
|
||||
wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.5.6/intel-igc-opencl-2_2.5.6+18417_amd64.deb && \
|
||||
wget https://github.com/intel/compute-runtime/releases/download/24.52.32224.5/intel-level-zero-gpu-dbgsym_1.6.32224.5_amd64.ddeb && \
|
||||
wget https://github.com/intel/compute-runtime/releases/download/24.52.32224.5/intel-level-zero-gpu_1.6.32224.5_amd64.deb && \
|
||||
wget https://github.com/intel/compute-runtime/releases/download/24.52.32224.5/intel-opencl-icd-dbgsym_24.52.32224.5_amd64.ddeb && \
|
||||
wget https://github.com/intel/compute-runtime/releases/download/24.52.32224.5/intel-opencl-icd_24.52.32224.5_amd64.deb && \
|
||||
wget https://github.com/intel/compute-runtime/releases/download/24.52.32224.5/libigdgmm12_22.5.5_amd64.deb && \
|
||||
dpkg -i *.deb && \
|
||||
rm -rf /tmp/neo && \
|
||||
mkdir -p /llm && \
|
||||
cd /llm && \
|
||||
git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git /llm/vllm && \
|
||||
rm -rf /tmp/neo && \
|
||||
# Install vllm
|
||||
git clone -b 0.6.6-pre https://github.com/analytics-zoo/vllm.git /llm/vllm && \
|
||||
cd /llm/vllm && \
|
||||
pip install setuptools-scm && \
|
||||
pip install --upgrade cmake && \
|
||||
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm && \
|
||||
# pip install -r /llm/vllm/requirements-xpu.txt && \
|
||||
# VLLM_TARGET_DEVICE=xpu python setup.py install && \
|
||||
pip install mpi4py fastapi uvicorn openai && \
|
||||
pip install gradio==4.43.0 && \
|
||||
# pip install transformers==4.44.2 && \
|
||||
# patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch && \
|
||||
pip install ray && \
|
||||
patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch
|
||||
pip install ray
|
||||
|
||||
COPY ./vllm_online_benchmark.py /llm/
|
||||
COPY ./vllm_offline_inference.py /llm/
|
||||
|
|
@ -106,10 +136,6 @@ COPY ./payload-1024.lua /llm/
|
|||
COPY ./start-vllm-service.sh /llm/
|
||||
COPY ./benchmark_vllm_throughput.py /llm/
|
||||
COPY ./benchmark_vllm_latency.py /llm/
|
||||
COPY ./start-fastchat-service.sh /llm/
|
||||
COPY ./start-pp_serving-service.sh /llm/
|
||||
COPY ./start-lightweight_serving-service.sh /llm/
|
||||
|
||||
ENV LD_LIBRARY_PATH /usr/local/lib/python3.11/dist-packages/intel_extension_for_pytorch/lib/:/opt/intel/oneapi/tbb/2021.12/env/../lib/intel64/gcc4.8:/opt/intel/oneapi/mpi/2021.12/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/2021.12/lib:/opt/intel/oneapi/mkl/2024.1/lib:/opt/intel/oneapi/ippcp/2021.11/lib/:/opt/intel/oneapi/ipp/2021.11/lib:/opt/intel/oneapi/dpl/2022.5/lib:/opt/intel/oneapi/dnnl/2024.1/lib:/opt/intel/oneapi/debugger/2024.1/opt/debugger/lib:/opt/intel/oneapi/dal/2024.2/lib:/opt/intel/oneapi/compiler/2024.1/opt/oclfpga/host/linux64/lib:/opt/intel/oneapi/compiler/2024.1/opt/compiler/lib:/opt/intel/oneapi/compiler/2024.1/lib:/opt/intel/oneapi/ccl/2021.12/lib/
|
||||
|
||||
WORKDIR /llm/
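The Dockerfile above is built in two stages: the first stage builds the `oneccl_bind_pt` wheel against PyTorch 2.6, and the second stage installs ipex-llm, the compute runtime, and the 0.6.6-pre vLLM fork. Below is a minimal, hedged sketch of building the image locally; it assumes the Dockerfile and its patch files live under `docker/llm/serving/xpu/docker` (the path shown for `ccl_torch.patch` in this commit) and uses an arbitrary tag.

```bash
# Hedged build sketch -- the directory and tag are assumptions; adjust to your checkout.
cd docker/llm/serving/xpu/docker
docker build \
  --build-arg http_proxy=$http_proxy \
  --build-arg https_proxy=$https_proxy \
  -t intelanalytics/ipex-llm-serving-xpu:latest \
  -f Dockerfile .
```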
@ -18,7 +18,7 @@ To map the `xpu` into the container, you need to specify `--device=/dev/dri` whe
|
|||
An example could be:
|
||||
```bash
|
||||
#!/bin/bash
|
||||
export DOCKER_IMAGE=intelanalytics/ipex-llm-serving-xpu:2.2.0-SNAPSHOT
|
||||
export DOCKER_IMAGE=intelanalytics/ipex-llm-serving-xpu:latest
|
||||
|
||||
sudo docker run -itd \
|
||||
--net=host \
|
||||
|
|
@ -59,86 +59,6 @@ To run Pipeline parallel serving using `IPEX-LLM` as backend, you can refer to t
|
|||
For convenience, we have included a file `/llm/start-pp_serving-service.sh` in the image.
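A minimal sketch of launching it from inside a running container is shown below. The script path comes from the image; the container name and any environment variables the script reads (for example a model path) are assumptions, so check the script itself before relying on them.

```bash
# Hedged sketch -- the container name and MODEL_PATH are placeholders/assumptions;
# inspect /llm/start-pp_serving-service.sh for the variables it actually reads.

# On the host: enter the running container
sudo docker exec -it demo-container bash

# Inside the container: point the script at a model and start it
cd /llm
export MODEL_PATH=/llm/models/Llama-2-7b-chat-hf   # hypothetical model location
bash start-pp_serving-service.sh
```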
|
||||
|
||||
|
||||
#### FastChat serving engine
|
||||
|
||||
To set up model serving with FastChat using `IPEX-LLM` as the backend, you can refer to this [quickstart](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/fastchat_quickstart.html#) or follow these quick steps to deploy a demo.
|
||||
|
||||
##### Quick Setup for FastChat with IPEX-LLM
|
||||
|
||||
1. **Start the Docker Container**
|
||||
|
||||
Run the following command to launch a Docker container with device access:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
export DOCKER_IMAGE=intelanalytics/ipex-llm-serving-xpu:latest
|
||||
|
||||
sudo docker run -itd \
|
||||
--net=host \
|
||||
--device=/dev/dri \
|
||||
--name=demo-container \
|
||||
# Example: map host model directory to container
|
||||
-v /LLM_MODELS/:/llm/models/ \
|
||||
--shm-size="16g" \
|
||||
# Optional: set proxy if needed
|
||||
-e http_proxy=... \
|
||||
-e https_proxy=... \
|
||||
-e no_proxy="127.0.0.1,localhost" \
|
||||
$DOCKER_IMAGE
|
||||
```
|
||||
|
||||
2. **Start the FastChat Service**
|
||||
|
||||
Enter the container and start the FastChat service:
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
# This command assumes that you have mapped the host model directory to the container
|
||||
# and the model directory is /llm/models/
|
||||
# We take Yi-1.5-34B as an example; you can replace it with your own model
|
||||
|
||||
ps -ef | grep "fastchat" | awk '{print $2}' | xargs kill -9
|
||||
pip install -U gradio==4.43.0
|
||||
|
||||
# start controller
|
||||
python -m fastchat.serve.controller &
|
||||
|
||||
export USE_XETLA=OFF
|
||||
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
|
||||
|
||||
export TORCH_LLM_ALLREDUCE=0
|
||||
export CCL_DG2_ALLREDUCE=1
|
||||
# Environment variables needed by CCL
|
||||
export CCL_WORKER_COUNT=4
|
||||
# pin ccl worker to cores
|
||||
# export CCL_WORKER_AFFINITY=32,33,34,35
|
||||
export FI_PROVIDER=shm
|
||||
export CCL_ATL_TRANSPORT=ofi
|
||||
export CCL_ZE_IPC_EXCHANGE=sockets
|
||||
export CCL_ATL_SHM=1
|
||||
|
||||
source /opt/intel/1ccl-wks/setvars.sh
|
||||
|
||||
python -m ipex_llm.serving.fastchat.vllm_worker \
|
||||
--model-path /llm/models/Yi-1.5-34B \
|
||||
--device xpu \
|
||||
--enforce-eager \
|
||||
--disable-async-output-proc \
|
||||
--distributed-executor-backend ray \
|
||||
--dtype float16 \
|
||||
--load-in-low-bit fp8 \
|
||||
--tensor-parallel-size 4 \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--max-model-len 4096 \
|
||||
--max-num-batched-tokens 8000 &
|
||||
|
||||
sleep 120
|
||||
|
||||
python -m fastchat.serve.gradio_web_server &
|
||||
```
|
||||
|
||||
This quick setup allows you to deploy FastChat with IPEX-LLM efficiently.
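Once the worker has registered with the controller, you can sanity-check the deployment before opening the web UI. The check below assumes FastChat's default controller port (21001, the port used by the image's start script) and Gradio's default port (7860); adjust if you changed them.

```bash
# Hedged sanity check -- ports assume the FastChat controller default (21001)
# and the Gradio default (7860).
curl -s -X POST http://localhost:21001/list_models   # models the controller knows about
# The Gradio web UI started above should then be reachable at http://<host-ip>:7860
```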
|
||||
|
||||
#### vLLM serving engine
|
||||
|
||||
To run the vLLM engine with `IPEX-LLM` as the backend, you can refer to this [document](https://github.com/intel-analytics/ipex-llm/blob/main/docs/mddocs/DockerGuides/vllm_docker_quickstart.md).
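For reference, here is a minimal sketch of serving an OpenAI-compatible endpoint with the vLLM fork installed in this image. It is based on the `ipex_llm.vllm.xpu.entrypoints.openai.api_server` module imported by the benchmark script in this commit; the exact flag set and the model path are assumptions, so consult `/llm/start-vllm-service.sh` and the linked quickstart for the authoritative invocation.

```bash
# Hedged sketch -- model path is a placeholder and flags may differ by version;
# see /llm/start-vllm-service.sh for the reference invocation.
source /opt/intel/1ccl-wks/setvars.sh
python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
  --model /llm/models/Llama-2-7b-chat-hf \
  --served-model-name Llama-2-7b-chat-hf \
  --device xpu \
  --load-in-low-bit fp8 \
  --port 8000
```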
@ -1,5 +1,6 @@
|
|||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
|
@ -9,55 +10,26 @@ import numpy as np
|
|||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
|
||||
from vllm.inputs import PromptInputs
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
|
||||
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
||||
# the engine will automatically process the request in multiple batches.
|
||||
llm = LLM(
|
||||
model=args.model,
|
||||
speculative_model=args.speculative_model,
|
||||
num_speculative_tokens=args.num_speculative_tokens,
|
||||
speculative_draft_tensor_parallel_size=\
|
||||
args.speculative_draft_tensor_parallel_size,
|
||||
tokenizer=args.tokenizer,
|
||||
quantization=args.quantization,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
dtype=args.dtype,
|
||||
max_model_len=args.max_model_len,
|
||||
enforce_eager=args.enforce_eager,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
quantization_param_path=args.quantization_param_path,
|
||||
device=args.device,
|
||||
ray_workers_use_nsight=args.ray_workers_use_nsight,
|
||||
use_v2_block_manager=args.use_v2_block_manager,
|
||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
||||
download_dir=args.download_dir,
|
||||
block_size=args.block_size,
|
||||
gpu_memory_utilization=args.gpu_memory_utilization,
|
||||
load_format=args.load_format,
|
||||
distributed_executor_backend=args.distributed_executor_backend,
|
||||
otlp_traces_endpoint=args.otlp_traces_endpoint,
|
||||
enable_prefix_caching=args.enable_prefix_caching,
|
||||
load_in_low_bit=args.load_in_low_bit,
|
||||
max_num_batched_tokens=args.max_num_batched_tokens,
|
||||
max_num_seqs=args.max_num_seqs,
|
||||
)
|
||||
llm = LLM(**dataclasses.asdict(engine_args), load_in_low_bit=args.load_in_low_bit)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=args.n,
|
||||
temperature=0.0 if args.use_beam_search else 1.0,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
use_beam_search=args.use_beam_search,
|
||||
ignore_eos=True,
|
||||
max_tokens=args.output_len,
|
||||
)
|
||||
|
|
@ -65,37 +37,26 @@ def main(args: argparse.Namespace):
|
|||
dummy_prompt_token_ids = np.random.randint(10000,
|
||||
size=(args.batch_size,
|
||||
args.input_len))
|
||||
dummy_inputs: List[PromptInputs] = [{
|
||||
dummy_prompts: List[PromptType] = [{
|
||||
"prompt_token_ids": batch
|
||||
} for batch in dummy_prompt_token_ids.tolist()]
|
||||
|
||||
def run_to_completion(profile_dir: Optional[str] = None):
|
||||
if profile_dir:
|
||||
if args.device == "xpu":
|
||||
with torch.autograd.profiler_legacy.profile(enabled=True, use_xpu=True) as p:
|
||||
llm.generate(dummy_inputs,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
print("Sort by CPU time total...")
|
||||
print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
|
||||
print("Sort by XPU time total...")
|
||||
print(p.key_averages().table(sort_by="self_xpu_time_total", row_limit=-1))
|
||||
else:
|
||||
with torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
# torch.profiler.ProfilerActivity.XPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
str(profile_dir))) as p:
|
||||
llm.generate(dummy_inputs,
|
||||
llm.generate(dummy_prompts,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
print(p.key_averages())
|
||||
else:
|
||||
start_time = time.perf_counter()
|
||||
llm.generate(dummy_inputs,
|
||||
llm.generate(dummy_prompts,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
end_time = time.perf_counter()
|
||||
|
|
@ -142,19 +103,6 @@ if __name__ == '__main__':
|
|||
parser = FlexibleArgumentParser(
|
||||
description='Benchmark the latency of processing a single batch of '
|
||||
'requests till completion.')
|
||||
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
||||
parser.add_argument('--speculative-model', type=str, default=None)
|
||||
parser.add_argument('--num-speculative-tokens', type=int, default=None)
|
||||
parser.add_argument('--speculative-draft-tensor-parallel-size',
|
||||
'-spec-draft-tp',
|
||||
type=int,
|
||||
default=None)
|
||||
parser.add_argument('--tokenizer', type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=[*QUANTIZATION_METHODS, None],
|
||||
default=None)
|
||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||
parser.add_argument('--input-len', type=int, default=32)
|
||||
parser.add_argument('--output-len', type=int, default=128)
|
||||
parser.add_argument('--batch-size', type=int, default=8)
|
||||
|
|
@ -171,45 +119,12 @@ if __name__ == '__main__':
|
|||
type=int,
|
||||
default=30,
|
||||
help='Number of iterations to run.')
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
parser.add_argument(
|
||||
'--max-model-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Maximum length of a sequence (including prompt and output). '
|
||||
'If None, will be derived from the model.')
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
"--load-in-low-bit",
|
||||
type=str,
|
||||
default='auto',
|
||||
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
|
||||
help='data type for model weights and activations. '
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
parser.add_argument('--enforce-eager',
|
||||
action='store_true',
|
||||
help='enforce eager mode and disable CUDA graph')
|
||||
parser.add_argument(
|
||||
'--kv-cache-dtype',
|
||||
type=str,
|
||||
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
|
||||
default="auto",
|
||||
help='Data type for kv cache storage. If "auto", will use model '
|
||||
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
|
||||
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
|
||||
parser.add_argument(
|
||||
'--quantization-param-path',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to the JSON file containing the KV cache scaling factors. '
|
||||
'This should generally be supplied, when KV cache dtype is FP8. '
|
||||
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
|
||||
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
|
||||
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
||||
'instead supported for common inference criteria.')
|
||||
choices=["sym_int4", "fp8", "fp8_e4m3", "fp16", "fp6"],
|
||||
default="sym_int4",
|
||||
help="Low-bit format quantization with IPEX-LLM")
|
||||
parser.add_argument(
|
||||
'--profile',
|
||||
action='store_true',
|
||||
|
|
@ -220,96 +135,12 @@ if __name__ == '__main__':
|
|||
default=None,
|
||||
help=('path to save the pytorch profiler output. Can be visualized '
|
||||
'with ui.perfetto.dev or Tensorboard.'))
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
|
||||
help='device type for vLLM execution, supporting CUDA, OpenVINO and '
|
||||
'CPU.')
|
||||
parser.add_argument('--block-size',
|
||||
type=int,
|
||||
default=16,
|
||||
help='block size of key/value cache')
|
||||
parser.add_argument(
|
||||
'--enable-chunked-prefill',
|
||||
action='store_true',
|
||||
help='If True, the prefill requests can be chunked based on the '
|
||||
'max_num_batched_tokens')
|
||||
parser.add_argument("--enable-prefix-caching",
|
||||
action='store_true',
|
||||
help="Enable automatic prefix caching")
|
||||
parser.add_argument('--use-v2-block-manager', action='store_true')
|
||||
parser.add_argument(
|
||||
"--ray-workers-use-nsight",
|
||||
action='store_true',
|
||||
help="If specified, use nsight to profile ray workers",
|
||||
)
|
||||
parser.add_argument('--download-dir',
|
||||
type=str,
|
||||
default=None,
|
||||
help='directory to download and load the weights, '
|
||||
'default to the default cache dir of huggingface')
|
||||
parser.add_argument(
|
||||
'--output-json',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to save the latency results in JSON format.')
|
||||
parser.add_argument('--gpu-memory-utilization',
|
||||
type=float,
|
||||
default=0.9,
|
||||
help='the fraction of GPU memory to be used for '
|
||||
'the model executor, which can range from 0 to 1.'
|
||||
'If unspecified, will use the default value of 0.9.')
|
||||
parser.add_argument(
|
||||
'--load-format',
|
||||
type=str,
|
||||
default=EngineArgs.load_format,
|
||||
choices=[
|
||||
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
|
||||
'bitsandbytes'
|
||||
],
|
||||
help='The format of the model weights to load.\n\n'
|
||||
'* "auto" will try to load the weights in the safetensors format '
|
||||
'and fall back to the pytorch bin format if safetensors format '
|
||||
'is not available.\n'
|
||||
'* "pt" will load the weights in the pytorch bin format.\n'
|
||||
'* "safetensors" will load the weights in the safetensors format.\n'
|
||||
'* "npcache" will load the weights in pytorch format and store '
|
||||
'a numpy cache to speed up the loading.\n'
|
||||
'* "dummy" will initialize the weights with random values, '
|
||||
'which is mainly for profiling.\n'
|
||||
'* "tensorizer" will load the weights using tensorizer from '
|
||||
'CoreWeave. See the Tensorize vLLM Model script in the Examples'
|
||||
'section for more information.\n'
|
||||
'* "bitsandbytes" will load the weights using bitsandbytes '
|
||||
'quantization.\n')
|
||||
parser.add_argument(
|
||||
'--distributed-executor-backend',
|
||||
choices=['ray', 'mp'],
|
||||
default=None,
|
||||
help='Backend to use for distributed serving. When more than 1 GPU '
|
||||
'is used, will be automatically set to "ray" if installed '
|
||||
'or "mp" (multiprocessing) otherwise.')
|
||||
parser.add_argument(
|
||||
'--otlp-traces-endpoint',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Target URL to which OpenTelemetry traces will be sent.')
|
||||
parser.add_argument(
|
||||
"--load-in-low-bit",
|
||||
type=str,
|
||||
choices=["sym_int4", "fp8", "fp8_e4m3", "fp16", "fp6"],
|
||||
default="sym_int4",
|
||||
help="Low-bit format quantization with IPEX-LLM")
|
||||
parser.add_argument('--max-num-batched-tokens',
|
||||
type=int,
|
||||
default=4096,
|
||||
help='maximum number of batched tokens per iteration')
|
||||
|
||||
parser.add_argument('--max-num-seqs',
|
||||
type=int,
|
||||
default=256,
|
||||
help='Maximum number of sequences per iteration.')
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
main(args)
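For context, after this change the latency benchmark takes its engine options from `EngineArgs.add_cli_args`, so only `--load-in-low-bit` and the input/output shape flags are script-specific. An illustrative invocation (the model path is a hypothetical placeholder) might look like:

```bash
# Illustrative only -- engine flags come from vLLM's EngineArgs and may vary by version.
python /llm/benchmark_vllm_latency.py \
  --model /llm/models/Llama-2-7b-chat-hf \
  --device xpu \
  --load-in-low-bit sym_int4 \
  --input-len 32 --output-len 128 --batch-size 8 \
  --enforce-eager
```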
@ -4,27 +4,96 @@ import dataclasses
|
|||
import json
|
||||
import random
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
from functools import cache
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
from ipex_llm.vllm.xpu.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
# from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.inputs import TextPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
|
||||
def sample_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int],
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
@dataclasses.dataclass
|
||||
class SampleRequest:
|
||||
"""A class representing a single inference request for benchmarking.
|
||||
|
||||
Attributes:
|
||||
prompt: The input text prompt for the model.
|
||||
prompt_len: The length of the prompt in tokens.
|
||||
expected_output_len: The expected length of the output in tokens.
|
||||
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
|
||||
images).
|
||||
lora_request: Optional LoRARequest specifying the LoRA to use.
|
||||
"""
|
||||
prompt: str
|
||||
prompt_len: int
|
||||
expected_output_len: int
|
||||
multi_modal_data: Optional[MultiModalDataDict] = None
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
|
||||
|
||||
def _get_prompt_for_image_model(question: str, *, model: str) -> str:
|
||||
"""Prepend and append special tokens around the question to form a prompt.
|
||||
|
||||
Args:
|
||||
question: The input question text to wrap with special tokens
|
||||
model: The name of the model being used, to determine which special
|
||||
tokens to add
|
||||
|
||||
Returns:
|
||||
The formatted prompt string with appropriate special tokens for the
|
||||
model
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported model name is provided
|
||||
"""
|
||||
model = model.lower()
|
||||
if "pixtral" in model:
|
||||
return f"<s>[INST]{question}\n[IMG][/INST]"
|
||||
raise ValueError(f"Unsupported model {model}")
|
||||
|
||||
|
||||
@cache
|
||||
def lora_path_on_disk(lora_path: str) -> str:
|
||||
return get_adapter_absolute_path(lora_path)
|
||||
|
||||
|
||||
lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}
|
||||
|
||||
|
||||
def get_random_lora_request(
|
||||
args: argparse.Namespace
|
||||
) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
|
||||
global lora_tokenizer_cache
|
||||
lora_id = random.randint(1, args.max_loras)
|
||||
lora_request = LoRARequest(lora_name=str(lora_id),
|
||||
lora_int_id=lora_id,
|
||||
lora_path=lora_path_on_disk(args.lora_path))
|
||||
if lora_id not in lora_tokenizer_cache:
|
||||
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
|
||||
return lora_request, lora_tokenizer_cache[lora_id]
|
||||
|
||||
|
||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
args: argparse.Namespace) -> List[SampleRequest]:
|
||||
|
||||
dataset_path: str = args.dataset
|
||||
num_requests: int = args.num_prompts
|
||||
fixed_output_len: Optional[int] = args.output_len
|
||||
model: str = args.model
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
|
|
@ -33,24 +102,46 @@ def sample_requests(
|
|||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
|
||||
# Shuffle the dataset.
|
||||
random.shuffle(dataset)
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: List[Tuple[str, int, int]] = []
|
||||
for i in range(len(dataset)):
|
||||
filtered_dataset: List[SampleRequest] = []
|
||||
for data in tqdm(dataset,
|
||||
total=len(filtered_dataset),
|
||||
desc="sampling requests"):
|
||||
if len(filtered_dataset) == num_requests:
|
||||
break
|
||||
|
||||
# Only keep the first two turns of each conversation.
|
||||
prompt = data["conversations"][0]["value"]
|
||||
completion = data["conversations"][1]["value"]
|
||||
|
||||
multi_modal_data: Optional[MultiModalDataDict] = None
|
||||
if "image" in data:
|
||||
multi_modal_data = multi_modal_data or {}
|
||||
image_path = data["image"]
|
||||
# TODO(vllm-project/vllm/issues/9778): Support multiple images.
|
||||
assert isinstance(image_path,
|
||||
str), "Only support single image input"
|
||||
try:
|
||||
multi_modal_data["image"] = Image.open(image_path).convert(
|
||||
"RGB")
|
||||
except FileNotFoundError:
|
||||
# Ignore datapoint where asset is missing
|
||||
continue
|
||||
prompt = _get_prompt_for_image_model(question=prompt, model=model)
|
||||
|
||||
request_tokenizer = tokenizer
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
if args.enable_lora:
|
||||
lora_request, lora_tokenizer = get_random_lora_request(args)
|
||||
if lora_tokenizer:
|
||||
request_tokenizer = lora_tokenizer
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = dataset[i][0]
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
completion = dataset[i][1]
|
||||
completion_token_ids = tokenizer(completion).input_ids
|
||||
prompt_token_ids = request_tokenizer(prompt).input_ids
|
||||
completion_token_ids = request_tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = len(completion_token_ids
|
||||
) if fixed_output_len is None else fixed_output_len
|
||||
|
|
@ -60,104 +151,110 @@ def sample_requests(
|
|||
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
||||
# Prune too long sequences.
|
||||
continue
|
||||
filtered_dataset.append((prompt, prompt_len, output_len))
|
||||
filtered_dataset.append(
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=multi_modal_data,
|
||||
lora_request=lora_request))
|
||||
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def run_vllm(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
requests: List[SampleRequest],
|
||||
n: int,
|
||||
low_bit: str,
|
||||
engine_args: EngineArgs,
|
||||
) -> float:
|
||||
from vllm import SamplingParams
|
||||
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
|
||||
llm = LLM(**dataclasses.asdict(engine_args), load_in_low_bit=low_bit)
|
||||
llm = LLM(**dataclasses.asdict(engine_args), load_in_low_bit=args.load_in_low_bit)
|
||||
|
||||
# Add the requests to the engine.
|
||||
warm_prompt = "hi " * (1024 - 1)
|
||||
warm_requests = [(warm_prompt, 1024, 1024)
|
||||
for _ in range(8)]
|
||||
|
||||
prompts: List[str] = []
|
||||
prompts: List[TextPrompt] = []
|
||||
sampling_params: List[SamplingParams] = []
|
||||
for prompt, _, output_len in warm_requests:
|
||||
prompts.append(prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=0.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=output_len,
|
||||
))
|
||||
llm.generate(prompts, sampling_params, use_tqdm=True)
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: List[str] = []
|
||||
sampling_params: List[SamplingParams] = []
|
||||
for prompt, _, output_len in requests:
|
||||
prompts.append(prompt)
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TextPrompt(prompt=request.prompt,
|
||||
multi_modal_data=request.multi_modal_data))
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=output_len,
|
||||
max_tokens=request.expected_output_len,
|
||||
))
|
||||
lora_requests: Optional[List[LoRARequest]] = None
|
||||
if engine_args.enable_lora:
|
||||
lora_requests = [request.lora_request for request in requests]
|
||||
|
||||
use_beam_search = False
|
||||
|
||||
if not use_beam_search:
|
||||
start = time.perf_counter()
|
||||
llm.generate(prompts, sampling_params, use_tqdm=True)
|
||||
llm.generate(prompts,
|
||||
sampling_params,
|
||||
lora_request=lora_requests,
|
||||
use_tqdm=True)
|
||||
end = time.perf_counter()
|
||||
else:
|
||||
prompts = [prompt for prompt, _, _ in requests]
|
||||
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
||||
prompts = [request.prompt for request in requests]
|
||||
# output_len should be the same for all requests.
|
||||
output_len = requests[0][2]
|
||||
for prompt, input_len, _output_len in requests:
|
||||
assert _output_len == output_len
|
||||
for request in requests:
|
||||
assert request.expected_output_len == output_len
|
||||
start = time.perf_counter()
|
||||
llm.beam_search(prompts,
|
||||
llm.beam_search(
|
||||
prompts,
|
||||
BeamSearchParams(
|
||||
beam_width=n,
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True)
|
||||
ignore_eos=True,
|
||||
))
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
requests: List[SampleRequest],
|
||||
n: int,
|
||||
engine_args: AsyncEngineArgs,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
) -> float:
|
||||
from vllm import SamplingParams
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, disable_frontend_multiprocessing) as llm:
|
||||
engine_args, disable_frontend_multiprocessing, load_in_low_bit=load_in_low_bit) as llm:
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: List[str] = []
|
||||
prompts: List[TextPrompt] = []
|
||||
sampling_params: List[SamplingParams] = []
|
||||
for prompt, _, output_len in requests:
|
||||
prompts.append(prompt)
|
||||
lora_requests: List[Optional[LoRARequest]] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TextPrompt(prompt=request.prompt,
|
||||
multi_modal_data=request.multi_modal_data))
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=output_len,
|
||||
max_tokens=request.expected_output_len,
|
||||
))
|
||||
lora_requests.append(request.lora_request)
|
||||
|
||||
generators = []
|
||||
start = time.perf_counter()
|
||||
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
|
||||
generator = llm.generate(prompt, sp, request_id=f"test{i}")
|
||||
for i, (prompt, sp,
|
||||
lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
|
||||
generator = llm.generate(prompt,
|
||||
sp,
|
||||
lora_request=lr,
|
||||
request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
|
|
@ -167,7 +264,7 @@ async def run_vllm_async(
|
|||
|
||||
|
||||
def run_hf(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
requests: List[SampleRequest],
|
||||
model: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
n: int,
|
||||
|
|
@ -225,14 +322,14 @@ def run_hf(
|
|||
|
||||
|
||||
def run_mii(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
requests: List[SampleRequest],
|
||||
model: str,
|
||||
tensor_parallel_size: int,
|
||||
output_len: int,
|
||||
) -> float:
|
||||
from mii import client, serve
|
||||
llm = serve(model, tensor_parallel=tensor_parallel_size)
|
||||
prompts = [prompt for prompt, _, _ in requests]
|
||||
prompts = [request.prompt for request in requests]
|
||||
|
||||
start = time.perf_counter()
|
||||
llm.generate(prompts, max_new_tokens=output_len)
|
||||
|
|
@ -250,23 +347,50 @@ def main(args: argparse.Namespace):
|
|||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
if args.dataset is None:
|
||||
vocab_size = tokenizer.vocab_size
|
||||
requests = []
|
||||
for _ in range(args.num_prompts):
|
||||
|
||||
request_tokenizer = tokenizer
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
if args.enable_lora:
|
||||
lora_request, lora_tokenizer = get_random_lora_request(args)
|
||||
if lora_tokenizer:
|
||||
request_tokenizer = lora_tokenizer
|
||||
|
||||
# Synthesize a prompt with the given input length.
|
||||
candidate_ids = [
|
||||
random.randint(0, vocab_size - 1)
|
||||
for _ in range(args.input_len)
|
||||
]
|
||||
# As tokenizer may add additional tokens like BOS, we need to try
|
||||
# different lengths to get the desired input length.
|
||||
for i in range(-10, 10):
|
||||
prompt = "hi " * (args.input_len + i)
|
||||
tokenized_prompt = tokenizer(prompt).input_ids
|
||||
if len(tokenized_prompt) == args.input_len:
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Failed to synthesize a prompt with {args.input_len} tokens.")
|
||||
requests = [(prompt, args.input_len, args.output_len)
|
||||
for _ in range(args.num_prompts)]
|
||||
else:
|
||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
|
||||
args.output_len)
|
||||
for _ in range(5): # Max attempts to correct
|
||||
candidate_prompt = request_tokenizer.decode(candidate_ids)
|
||||
tokenized_len = len(request_tokenizer.encode(candidate_prompt))
|
||||
|
||||
if tokenized_len == args.input_len:
|
||||
break
|
||||
|
||||
# Adjust length based on difference
|
||||
diff = args.input_len - tokenized_len
|
||||
if diff > 0:
|
||||
candidate_ids.extend([
|
||||
random.randint(100, vocab_size - 100)
|
||||
for _ in range(diff)
|
||||
])
|
||||
else:
|
||||
candidate_ids = candidate_ids[:diff]
|
||||
requests.append(
|
||||
SampleRequest(prompt=candidate_prompt,
|
||||
prompt_len=args.input_len,
|
||||
expected_output_len=args.output_len,
|
||||
lora_request=lora_request))
|
||||
else:
|
||||
requests = sample_requests(tokenizer, args)
|
||||
|
||||
is_multi_modal = any(request.multi_modal_data is not None
|
||||
for request in requests)
|
||||
if args.backend == "vllm":
|
||||
if args.async_engine:
|
||||
elapsed_time = uvloop.run(
|
||||
|
|
@ -275,9 +399,10 @@ def main(args: argparse.Namespace):
|
|||
args.n,
|
||||
AsyncEngineArgs.from_cli_args(args),
|
||||
args.disable_frontend_multiprocessing,
|
||||
args.load_in_low_bit,
|
||||
))
|
||||
else:
|
||||
elapsed_time = run_vllm(requests, args.n, args.load_in_low_bit,
|
||||
elapsed_time = run_vllm(requests, args.n,
|
||||
EngineArgs.from_cli_args(args))
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
|
|
@ -288,9 +413,15 @@ def main(args: argparse.Namespace):
|
|||
args.output_len)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
total_num_tokens = sum(prompt_len + output_len
|
||||
for _, prompt_len, output_len in requests)
|
||||
total_output_tokens = sum(output_len for _, _, output_len in requests)
|
||||
total_num_tokens = sum(request.prompt_len + request.expected_output_len
|
||||
for request in requests)
|
||||
total_output_tokens = sum(request.expected_output_len
|
||||
for request in requests)
|
||||
if is_multi_modal:
|
||||
print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
|
||||
"following metrics are not accurate because image tokens are not"
|
||||
" counted. See vllm-project/vllm/issues/9778 for details.")
|
||||
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
|
||||
|
|
@ -317,7 +448,9 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset.")
|
||||
help="Path to the dataset. The dataset is expected to "
|
||||
"be a json in form of List[Dict[..., conversations: "
|
||||
"List[Dict[..., value: <prompt_or_response>]]]]")
|
||||
parser.add_argument("--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
|
|
@ -352,12 +485,21 @@ if __name__ == "__main__":
|
|||
action='store_true',
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.")
|
||||
# IPEX-LLM optimization
|
||||
parser.add_argument(
|
||||
"--load-in-low-bit",
|
||||
type=str,
|
||||
choices=["sym_int4", "fp8", "fp8_e4m3", "fp16", "fp6"],
|
||||
choices=["sym_int4", "woq_int4", "fp8", "fp8_e4m3", "fp16", "fp6"],
|
||||
default="sym_int4",
|
||||
help="Low-bit format quantization with IPEX-LLM")
|
||||
# LoRA
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the lora adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
|
|
@ -367,6 +509,8 @@ if __name__ == "__main__":
|
|||
assert args.output_len is not None
|
||||
else:
|
||||
assert args.input_len is None
|
||||
if args.enable_lora:
|
||||
assert args.lora_path is not None
|
||||
|
||||
if args.backend == "vllm":
|
||||
if args.hf_max_batch_size is not None:
|
||||
|
|
@ -376,6 +520,9 @@ if __name__ == "__main__":
|
|||
raise ValueError("HF max batch size is required for HF backend.")
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
if args.enable_lora is not None:
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM"
|
||||
" backend")
|
||||
elif args.backend == "mii":
|
||||
if args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
|
|
@ -388,5 +535,7 @@ if __name__ == "__main__":
|
|||
if args.tokenizer != args.model:
|
||||
raise ValueError("Tokenizer must be the same as the model for MII "
|
||||
"backend.")
|
||||
if args.enable_lora is not None:
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM"
|
||||
" backend")
|
||||
main(args)
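Similarly, the throughput benchmark now builds its requests as `SampleRequest` objects and exposes `--load-in-low-bit` (plus optional LoRA flags) on top of `AsyncEngineArgs`. An illustrative invocation with a ShareGPT-style dataset (both paths are hypothetical placeholders):

```bash
# Illustrative only -- dataset and model paths are placeholders; engine flags
# come from vLLM's AsyncEngineArgs and may vary by version.
python /llm/benchmark_vllm_throughput.py \
  --backend vllm \
  --model /llm/models/Llama-2-7b-chat-hf \
  --dataset /llm/ShareGPT_V3_unfiltered_cleaned_split.json \
  --num-prompts 1000 \
  --device xpu \
  --load-in-low-bit fp8
```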
47
docker/llm/serving/xpu/docker/ccl_torch.patch
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
diff --git a/setup.py b/setup.py
|
||||
index 67016cb..305be5c 100644
|
||||
--- a/setup.py
|
||||
+++ b/setup.py
|
||||
@@ -118,8 +118,8 @@ class BuildCMakeExt(BuildExtension):
|
||||
if compute_backend == 'dpcpp':
|
||||
runtime = 'dpcpp'
|
||||
build_options['COMPUTE_BACKEND'] = compute_backend
|
||||
- import intel_extension_for_pytorch
|
||||
- build_options['CMAKE_PREFIX_PATH'] += ";" + intel_extension_for_pytorch.cmake_prefix_path
|
||||
+ # import intel_extension_for_pytorch
|
||||
+ # build_options['CMAKE_PREFIX_PATH'] += ";" + intel_extension_for_pytorch.cmake_prefix_path
|
||||
if "DPCPP_GCC_INSTALL_DIR" in my_env:
|
||||
exist_cflags = "CFLAGS" in my_env
|
||||
cflags = ""
|
||||
diff --git a/src/gpu/CMakeLists.txt b/src/gpu/CMakeLists.txt
|
||||
index 5628f31..d1589c4 100644
|
||||
--- a/src/gpu/CMakeLists.txt
|
||||
+++ b/src/gpu/CMakeLists.txt
|
||||
@@ -1,4 +1,4 @@
|
||||
-find_package(IPEX REQUIRED)
|
||||
+# find_package(IPEX REQUIRED)
|
||||
|
||||
set(CCL_DPCPP_SRCS dpcpp_ccl.cpp ze_exception.hpp allreduce.h sycl_misc.hpp runtime.hpp cxxopts.hpp)
|
||||
|
||||
@@ -9,7 +9,7 @@ add_library(oneccl_bindings_for_pytorch_xpu SHARED ${CCL_DPCPP_SRCS})
|
||||
|
||||
target_link_libraries(oneccl_bindings_for_pytorch_xpu PUBLIC ${DEPENDS_LIB})
|
||||
target_link_libraries(oneccl_bindings_for_pytorch_xpu PUBLIC oneccl_bindings_for_pytorch)
|
||||
-target_link_libraries(oneccl_bindings_for_pytorch_xpu PUBLIC intel-ext-pt-gpu)
|
||||
+# target_link_libraries(oneccl_bindings_for_pytorch_xpu PUBLIC intel-ext-pt-gpu)
|
||||
|
||||
foreach(RPATH ${CMAKE_INSTALL_RPATH})
|
||||
set_target_properties(oneccl_bindings_for_pytorch_xpu PROPERTIES LINK_FLAGS "-Wl,-rpath,${RPATH}")
|
||||
diff --git a/src/gpu/dpcpp_ccl.cpp b/src/gpu/dpcpp_ccl.cpp
|
||||
index 1631b85..0945031 100644
|
||||
--- a/src/gpu/dpcpp_ccl.cpp
|
||||
+++ b/src/gpu/dpcpp_ccl.cpp
|
||||
@@ -32,7 +32,7 @@
|
||||
#include <ATen/record_function.h>
|
||||
#include <ProcessGroupCCL.hpp>
|
||||
#include <dispatch_stub.h>
|
||||
-#include <ipex.h>
|
||||
+// #include <ipex.h>
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
//#include "allreduce.h"
|
||||
|
|
@ -1,209 +0,0 @@
|
|||
--- a/gradio_web_server.py
|
||||
+++ b/gradio_web_server_new.py
|
||||
@@ -9,8 +9,10 @@ import hashlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
+import pandas as pd
|
||||
import time
|
||||
import uuid
|
||||
+import numpy as np
|
||||
|
||||
import gradio as gr
|
||||
import requests
|
||||
@@ -241,7 +243,7 @@ def clear_history(request: gr.Request):
|
||||
ip = get_ip(request)
|
||||
logger.info(f"clear_history. ip: {ip}")
|
||||
state = None
|
||||
- return (state, [], "", None) + (disable_btn,) * 5
|
||||
+ return (state, [], "", None, "", "", "", "") + (disable_btn,) * 5
|
||||
|
||||
|
||||
def get_ip(request: gr.Request):
|
||||
@@ -354,6 +356,18 @@ def is_limit_reached(model_name, ip):
|
||||
return None
|
||||
|
||||
|
||||
+def handle_latency_metrics(first_token_time, next_token_time):
|
||||
+ # next token time is a numpy array...
|
||||
+ # first token time might be None
|
||||
+ first_token_latency = "None"
|
||||
+ next_token_latency = "None"
|
||||
+ if first_token_time is not None:
|
||||
+ first_token_latency = f"{first_token_time * 1000 :.2f} ms"
|
||||
+ if next_token_time.size > 0:
|
||||
+ next_token_latency = f"{np.mean(next_token_time) * 1000 :.2f} ms"
|
||||
+ return first_token_latency, next_token_latency
|
||||
+
|
||||
+
|
||||
def bot_response(
|
||||
state,
|
||||
temperature,
|
||||
@@ -372,7 +386,7 @@ def bot_response(
|
||||
if state.skip_next:
|
||||
# This generate call is skipped due to invalid inputs
|
||||
state.skip_next = False
|
||||
- yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||
+ yield (state, state.to_gradio_chatbot(), "None", "None", "None", "None") + (no_change_btn,) * 5
|
||||
return
|
||||
|
||||
if apply_rate_limit:
|
||||
@@ -381,7 +395,7 @@ def bot_response(
|
||||
error_msg = RATE_LIMIT_MSG + "\n\n" + ret["reason"]
|
||||
logger.info(f"rate limit reached. ip: {ip}. error_msg: {ret['reason']}")
|
||||
state.conv.update_last_message(error_msg)
|
||||
- yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||
+ yield (state, state.to_gradio_chatbot(), "None", "None", "None", "None") + (no_change_btn,) * 5
|
||||
return
|
||||
|
||||
conv, model_name = state.conv, state.model_name
|
||||
@@ -404,6 +418,10 @@ def bot_response(
|
||||
yield (
|
||||
state,
|
||||
state.to_gradio_chatbot(),
|
||||
+ "None",
|
||||
+ "None",
|
||||
+ "None",
|
||||
+ "None",
|
||||
disable_btn,
|
||||
disable_btn,
|
||||
disable_btn,
|
||||
@@ -444,18 +462,32 @@ def bot_response(
|
||||
)
|
||||
|
||||
conv.update_last_message("▌")
|
||||
- yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
||||
+ # We probably need to change this method
|
||||
+ yield (state, state.to_gradio_chatbot(), "None", "None", "None", "None") + (disable_btn,) * 5
|
||||
+ prompt_tokens = 0
|
||||
+ generated_tokens = 0
|
||||
+ first_token_latency = None
|
||||
+ next_token_latencies = np.array([])
|
||||
+ start_time = time.time()
|
||||
|
||||
try:
|
||||
for i, data in enumerate(stream_iter):
|
||||
if data["error_code"] == 0:
|
||||
+ prompt_tokens = data["usage"]["prompt_tokens"]
|
||||
+ generated_tokens = data["usage"]["completion_tokens"]
|
||||
output = data["text"].strip()
|
||||
conv.update_last_message(output + "▌")
|
||||
- yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
||||
+ if first_token_latency is None:
|
||||
+ first_token_latency = time.time() - start_time
|
||||
+ else:
|
||||
+ next_token_latencies = np.append(next_token_latencies, time.time() - start_time)
|
||||
+ start_time = time.time()
|
||||
+ first_latency, next_latency = handle_latency_metrics(first_token_latency, next_token_latencies)
|
||||
+ yield (state, state.to_gradio_chatbot(), prompt_tokens, generated_tokens, first_latency, next_latency) + (disable_btn,) * 5
|
||||
else:
|
||||
output = data["text"] + f"\n\n(error_code: {data['error_code']})"
|
||||
conv.update_last_message(output)
|
||||
- yield (state, state.to_gradio_chatbot()) + (
|
||||
+ yield (state, state.to_gradio_chatbot(), "None", "None", "None", "None") + (
|
||||
disable_btn,
|
||||
disable_btn,
|
||||
disable_btn,
|
||||
@@ -465,13 +497,14 @@ def bot_response(
|
||||
return
|
||||
output = data["text"].strip()
|
||||
conv.update_last_message(output)
|
||||
- yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
||||
+ first_latency, next_latency = handle_latency_metrics(first_token_latency, next_token_latencies)
|
||||
+ yield (state, state.to_gradio_chatbot(), prompt_tokens, generated_tokens, first_latency, next_latency) + (enable_btn,) * 5
|
||||
except requests.exceptions.RequestException as e:
|
||||
conv.update_last_message(
|
||||
f"{SERVER_ERROR_MSG}\n\n"
|
||||
f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
|
||||
)
|
||||
- yield (state, state.to_gradio_chatbot()) + (
|
||||
+ yield (state, state.to_gradio_chatbot(), "None", "None", "None", "None") + (
|
||||
disable_btn,
|
||||
disable_btn,
|
||||
disable_btn,
|
||||
@@ -484,7 +517,7 @@ def bot_response(
|
||||
f"{SERVER_ERROR_MSG}\n\n"
|
||||
f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
|
||||
)
|
||||
- yield (state, state.to_gradio_chatbot()) + (
|
||||
+ yield (state, state.to_gradio_chatbot(), "None", "None", "None", "None") + (
|
||||
disable_btn,
|
||||
disable_btn,
|
||||
disable_btn,
|
||||
@@ -646,7 +679,8 @@ def build_single_model_ui(models, add_promotion_links=False):
|
||||
)
|
||||
|
||||
notice_markdown = f"""
|
||||
-# 🏔️ Chat with Open Large Language Models
|
||||
+# 🏔️ ChatBot based Xeon-W & Arc GPUs
|
||||
+### Deployed with IPEX-LLM
|
||||
{promotion}
|
||||
"""
|
||||
|
||||
@@ -717,6 +751,22 @@ def build_single_model_ui(models, add_promotion_links=False):
|
||||
label="Max output tokens",
|
||||
)
|
||||
|
||||
+ with gr.Row():
|
||||
+ with gr.Column():
|
||||
+ gr.Markdown("### Performance Metrics")
|
||||
+ prompt_token = gr.Label(
|
||||
+ label="Prompt token length:",
|
||||
+ )
|
||||
+ next_token = gr.Label(
|
||||
+ label="Generated token length:",
|
||||
+ )
|
||||
+ first_token_latency = gr.Label(
|
||||
+ label="First token Latency:",
|
||||
+ )
|
||||
+ next_token_latency = gr.Label(
|
||||
+ label="Next token Latency:",
|
||||
+ )
|
||||
+
|
||||
if add_promotion_links:
|
||||
gr.Markdown(acknowledgment_md, elem_id="ack_markdown")
|
||||
|
||||
@@ -743,9 +793,9 @@ def build_single_model_ui(models, add_promotion_links=False):
|
||||
).then(
|
||||
bot_response,
|
||||
[state, temperature, top_p, max_output_tokens],
|
||||
- [state, chatbot] + btn_list,
|
||||
+ [state, chatbot, prompt_token, next_token, first_token_latency, next_token_latency] + btn_list,
|
||||
)
|
||||
- clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)
|
||||
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox, prompt_token, next_token, first_token_latency, next_token_latency] + btn_list)
|
||||
|
||||
model_selector.change(
|
||||
clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
|
||||
@@ -758,7 +808,7 @@ def build_single_model_ui(models, add_promotion_links=False):
|
||||
).then(
|
||||
bot_response,
|
||||
[state, temperature, top_p, max_output_tokens],
|
||||
- [state, chatbot] + btn_list,
|
||||
+ [state, chatbot, prompt_token, next_token, first_token_latency, next_token_latency] + btn_list,
|
||||
)
|
||||
send_btn.click(
|
||||
add_text,
|
||||
@@ -767,7 +817,7 @@ def build_single_model_ui(models, add_promotion_links=False):
|
||||
).then(
|
||||
bot_response,
|
||||
[state, temperature, top_p, max_output_tokens],
|
||||
- [state, chatbot] + btn_list,
|
||||
+ [state, chatbot, prompt_token, next_token, first_token_latency, next_token_latency] + btn_list,
|
||||
)
|
||||
|
||||
return [state, model_selector]
|
||||
@@ -775,7 +825,7 @@ def build_single_model_ui(models, add_promotion_links=False):
|
||||
|
||||
def build_demo(models):
|
||||
with gr.Blocks(
|
||||
- title="Chat with Open Large Language Models",
|
||||
+ title="ChatBot based Xeon-W & Arc GPUs",
|
||||
theme=gr.themes.Default(),
|
||||
css=block_css,
|
||||
) as demo:
|
||||
@@ -885,3 +935,4 @@ if __name__ == "__main__":
|
||||
auth=auth,
|
||||
root_path=args.gradio_root_path,
|
||||
)
|
||||
+
@@ -1,125 +0,0 @@
#!/bin/bash
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 [-w --worker <model_worker|vllm_worker>] [--help]"
|
||||
echo "--help: Print help message."
|
||||
echo "The following environment variables can be set."
|
||||
echo "MODEL_PATH (default: empty)."
|
||||
echo "LOW_BIT_FORMAT (default: sym_int4)"
|
||||
echo "CONTROLLER_HOST (default: localhost)."
|
||||
echo "CONTROLLER_PORT (default: 21001)."
|
||||
echo "WORKER_HOST (default: localhost)."
|
||||
echo "WORKER_PORT (default: 21002)."
|
||||
echo "API_HOST (default: localhost)."
|
||||
echo "API_PORT (default: 8000)."
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Default values
|
||||
controller_host="localhost"
|
||||
controller_port="21001"
|
||||
worker_host="localhost"
|
||||
worker_port="21002"
|
||||
api_host="localhost"
|
||||
api_port="8000"
|
||||
model_path=""
|
||||
mode=""
|
||||
dispatch_method="shortest_queue" # shortest_queue or lottery
|
||||
stream_interval=1
|
||||
worker_type="model_worker"
|
||||
low_bit_format="sym_int4"
|
||||
|
||||
# We do not have any arguments, just run bash
|
||||
|
||||
# Parse command-line options
|
||||
options=$(getopt -o "hw:" --long "help,worker:" -n "$0" -- "$@")
|
||||
if [ $? != 0 ]; then
|
||||
usage
|
||||
fi
|
||||
eval set -- "$options"
|
||||
|
||||
while true; do
|
||||
case "$1" in
|
||||
-w|--worker)
|
||||
worker_type="$2"
|
||||
[[ $worker_type == "model_worker" || $worker_type == "vllm_worker" ]] || usage
|
||||
shift 2
|
||||
;;
|
||||
-h|--help)
|
||||
usage
|
||||
;;
|
||||
--)
|
||||
shift
|
||||
break
|
||||
;;
|
||||
*)
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [ "$worker_type" == "model_worker" ]; then
|
||||
worker_type="ipex_llm.serving.fastchat.ipex_llm_worker"
|
||||
elif [ "$worker_type" == "vllm_worker" ]; then
|
||||
worker_type="ipex_llm.serving.fastchat.vllm_worker"
|
||||
fi
|
||||
|
||||
if [[ -n $CONTROLLER_HOST ]]; then
|
||||
controller_host=$CONTROLLER_HOST
|
||||
fi
|
||||
|
||||
if [[ -n $CONTROLLER_PORT ]]; then
|
||||
controller_port=$CONTROLLER_PORT
|
||||
fi
|
||||
|
||||
if [[ -n $LOW_BIT_FORMAT ]]; then
|
||||
low_bit_format=$LOW_BIT_FORMAT
|
||||
fi
|
||||
|
||||
if [[ -n $WORKER_HOST ]]; then
|
||||
worker_host=$WORKER_HOST
|
||||
fi
|
||||
|
||||
if [[ -n $WORKER_PORT ]]; then
|
||||
worker_port=$WORKER_PORT
|
||||
fi
|
||||
|
||||
if [[ -n $MODEL_PATH ]]; then
|
||||
model_path=$MODEL_PATH
|
||||
fi
|
||||
|
||||
if [[ -n $API_HOST ]]; then
|
||||
api_host=$API_HOST
|
||||
fi
|
||||
|
||||
if [[ -n $API_PORT ]]; then
|
||||
api_port=$API_PORT
|
||||
fi
|
||||
|
||||
if [[ -n $DISPATCH_METHOD ]]; then
|
||||
dispatch_method=$DISPATCH_METHOD
|
||||
fi
|
||||
|
||||
if [[ -n $STREAM_INTERVAL ]]; then
|
||||
stream_interval=$STREAM_INTERVAL
|
||||
fi
|
||||
|
||||
controller_address="http://$controller_host:$controller_port"
|
||||
echo "Controller address: $controller_address"
|
||||
python3 -m fastchat.serve.controller --host $controller_host --port $controller_port --dispatch-method $dispatch_method &
|
||||
|
||||
worker_address="http://$worker_host:$worker_port"
|
||||
echo "Worker type: $worker_type"
|
||||
echo "Worker address: $worker_address"
|
||||
|
||||
if [ "$worker_type" == "ipex_llm.serving.fastchat.ipex_llm_worker" ]; then
|
||||
python3 -m "$worker_type" --model-path $model_path --device xpu --low-bit $low_bit_format --host $worker_host --port $worker_port --worker-address $worker_address --controller-address $controller_address --stream-interval $stream_interval &
|
||||
elif [ "$worker_type" == "ipex_llm.serving.fastchat.vllm_worker" ]; then
|
||||
python3 -m "$worker_type" --model-path $model_path --device xpu --load-in-low-bit $low_bit_format --host $worker_host --port $worker_port --worker-address $worker_address --controller-address $controller_address --enforce-eager --gpu-memory-utilization 0.85 &
|
||||
fi
|
||||
|
||||
sleep 10
|
||||
|
||||
api_address="http://$api_host:$api_port"
|
||||
echo "OpenAI API address: $api_address"
|
||||
python3 -m fastchat.serve.openai_api_server --host $api_host --port $api_port --controller-address $controller_address
@@ -1,4 +0,0 @@
cd /llm/lightweight_serving
|
||||
model_path="/llm/models/Llama-2-7b-chat-hf"
|
||||
low_bit="sym_int4"
|
||||
python lightweight_serving.py --repo-id-or-model-path $model_path --low-bit $low_bit
@@ -667,7 +667,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
out_features,
|
||||
mp_group,
|
||||
None,
|
||||
None,
|
||||
optimize_lm_head,
|
||||
None
|
||||
)
@@ -13,18 +13,28 @@
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import Dict, Optional
|
||||
from vllm.logger import init_logger
|
||||
from typing import Dict, Optional, Any, Union, Type
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.utils import Counter
|
||||
from vllm.config import EngineConfig
|
||||
from vllm.config import VllmConfig
|
||||
from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.engine.metrics import StatLoggerBase
|
||||
from vllm.engine.multiprocessing.engine import MQLLMEngine
|
||||
import signal
|
||||
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
|
||||
TaskOption)
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
|
||||
from vllm import envs
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
import os
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
@@ -35,7 +45,7 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
def from_engine_args(
|
||||
cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
engine_config: Optional[EngineConfig] = None,
|
||||
engine_config: Optional[VllmConfig] = None,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
load_in_low_bit: str = "sym_int4",
@@ -49,6 +59,27 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
usage_context=usage_context, stat_loggers=stat_loggers)
|
||||
|
||||
|
||||
class IPEXLLMAsyncV1Engine(AsyncLLM):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
engine_config: Optional[VllmConfig] = None,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None, # noqa
|
||||
) -> "AsyncLLM":
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
|
||||
start_engine_loop=start_engine_loop,
|
||||
usage_context=usage_context, stat_loggers=stat_loggers)
|
||||
|
||||
|
||||
class IPEXLLMClass(LLM):
|
||||
def __init__(
|
||||
self,
@@ -57,6 +88,7 @@ class IPEXLLMClass(LLM):
tokenizer_mode: str = "auto",
|
||||
skip_tokenizer_init: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
allowed_local_media_path: str = "",
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
@@ -64,28 +96,48 @@ class IPEXLLMClass(LLM):
tokenizer_revision: Optional[str] = None,
|
||||
seed: int = 0,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
swap_space: int = 4,
|
||||
swap_space: float = 4,
|
||||
cpu_offload_gb: float = 0,
|
||||
enforce_eager: bool = False,
|
||||
max_context_len_to_capture: Optional[int] = None,
|
||||
enforce_eager: Optional[bool] = None,
|
||||
max_seq_len_to_capture: int = 8192,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
disable_async_output_proc: bool = True,
|
||||
hf_overrides: Optional[HfOverrides] = None,
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]]=None,
|
||||
# After positional args are removed, move this right below `model`
|
||||
task: TaskOption = "auto",
|
||||
override_pooler_config: Optional[PoolerConfig] = None,
|
||||
compilation_config: Optional[Union[int, Dict[str, Any]]]=None,
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
'''
|
||||
LLM constructor.
|
||||
|
||||
Note: if enforce_eager is unset (enforce_eager is None)
|
||||
it defaults to False.
|
||||
'''
|
||||
|
||||
if "disable_log_stats" not in kwargs:
|
||||
kwargs["disable_log_stats"] = True
|
||||
removed_vision_keys = ("image_token_id", "image_feature_size",
|
||||
"image_input_shape", "image_input_type")
|
||||
if any(k in kwargs for k in removed_vision_keys):
|
||||
raise TypeError( # noqa
|
||||
"There is no need to pass vision-related arguments anymore.")
|
||||
|
||||
if compilation_config is not None:
|
||||
if isinstance(compilation_config, (int, dict)):
|
||||
compilation_config_instance = CompilationConfig.from_cli(
|
||||
str(compilation_config))
|
||||
else:
|
||||
compilation_config_instance = compilation_config
|
||||
else:
|
||||
compilation_config_instance = None
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
task=task,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
trust_remote_code=trust_remote_code,
|
||||
allowed_local_media_path=allowed_local_media_path,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=dtype,
|
||||
quantization=quantization,
@@ -96,16 +148,53 @@ class IPEXLLMClass(LLM):
swap_space=swap_space,
|
||||
cpu_offload_gb=cpu_offload_gb,
|
||||
enforce_eager=enforce_eager,
|
||||
max_context_len_to_capture=max_context_len_to_capture,
|
||||
max_seq_len_to_capture=max_seq_len_to_capture,
|
||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
hf_overrides=hf_overrides,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
override_pooler_config=override_pooler_config,
|
||||
compilation_config=compilation_config_instance,
|
||||
**kwargs,
|
||||
)
|
||||
self.llm_engine = IPEXLLMLLMEngine.from_engine_args(
|
||||
# Logic to switch between engines is done at runtime instead of import
|
||||
# to avoid import order issues
|
||||
self.engine_class = self.get_engine_class()
|
||||
self.llm_engine = self.engine_class.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.LLM_CLASS,
|
||||
load_in_low_bit=load_in_low_bit)
|
||||
|
||||
self.request_counter = Counter()
|
||||
|
||||
@staticmethod
|
||||
def get_engine_class() -> Type[LLMEngine]:
|
||||
if envs.VLLM_USE_V1:
|
||||
return IPEXLLMLLMV1Engine
|
||||
return IPEXLLMLLMEngine
|
||||
|
||||
|
||||
class IPEXLLMLLMV1Engine(V1LLMEngine):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: EngineArgs,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
|
||||
enable_multiprocessing: bool = False,
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
) -> "LLMEngine":
|
||||
"""Creates an LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
return super().from_engine_args(engine_args,
|
||||
usage_context,
|
||||
stat_loggers,
|
||||
enable_multiprocessing)
|
||||
|
||||
|
||||
class IPEXLLMLLMEngine(LLMEngine):
|
||||
def __init__(self, *args, **kwargs):
@@ -134,12 +223,13 @@ class IPEXLLMMQLLMEngine(MQLLMEngine):
|
||||
|
||||
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
|
||||
ipc_path: str, load_in_low_bit: str):
|
||||
ipc_path: str, load_in_low_bit: str, engine_alive):
|
||||
|
||||
def signal_handler(*_) -> None:
|
||||
# Interrupt server on sigterm
|
||||
raise KeyboardInterrupt("MQLLMEngine terminated") # noqa
|
||||
|
||||
try:
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args,
@@ -147,3 +237,10 @@ def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path=ipc_path,
|
||||
load_in_low_bit=load_in_low_bit)
|
||||
engine.start()
|
||||
except BaseException as e:
|
||||
logger.exception(e)
|
||||
engine_alive.value = False
|
||||
raise e # noqa
|
||||
|
||||
if os.getenv("VLLM_USE_V1"):
|
||||
IPEXLLMAsyncLLMEngine = IPEXLLMAsyncV1Engine
@@ -1,4 +1,5 @@
import asyncio
|
||||
import atexit
|
||||
import importlib
|
||||
import inspect
|
||||
import multiprocessing
@@ -7,11 +8,12 @@ import re
import signal
|
||||
import socket
|
||||
import tempfile
|
||||
import uuid
|
||||
from argparse import Namespace
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncIterator, Set
|
||||
from typing import AsyncIterator, Optional, Set, Tuple
|
||||
|
||||
import uvloop
|
||||
from fastapi import APIRouter, FastAPI, Request
@@ -29,9 +31,13 @@ from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from ipex_llm.vllm.xpu.engine import run_mp_engine
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||
validate_parsed_serve_args)
|
||||
|
||||
# from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -41,8 +47,12 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse, ErrorResponse,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
ScoreRequest, ScoreResponse,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
UnloadLoraAdapterRequest)
@@ -50,12 +60,20 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
|
||||
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.entrypoints.utils import with_cancellation
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
|
||||
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
|
||||
is_valid_ipv6_address, set_ulimit)
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds
@@ -111,7 +129,7 @@ async def build_async_engine_client(
async def build_async_engine_client_from_engine_args(
|
||||
engine_args: AsyncEngineArgs,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
load_in_low_bit: str = 'sym_int4',
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
) -> AsyncIterator[EngineClient]:
|
||||
"""
|
||||
Create EngineClient, either:
@@ -124,25 +142,19 @@ async def build_async_engine_client_from_engine_args(
# Fall back
|
||||
# TODO: fill out feature matrix.
|
||||
if (MQLLMEngineClient.is_unsupported_config(engine_args)
|
||||
or disable_frontend_multiprocessing):
|
||||
engine_config = engine_args.create_engine_config()
|
||||
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
|
||||
"uses_ray", False)
|
||||
|
||||
build_engine = partial(AsyncLLMEngine.from_engine_args,
|
||||
or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
|
||||
engine_client: Optional[EngineClient] = None
|
||||
try:
|
||||
# When starting this, we are actually starting with the V1Engine
|
||||
# Here we are doing a classification, we will need to do this in IPEX-LLM
|
||||
engine_client = AsyncLLMEngine.from_engine_args(
|
||||
engine_args=engine_args,
|
||||
load_in_low_bit=load_in_low_bit,
|
||||
engine_config=engine_config,
|
||||
usage_context=UsageContext.OPENAI_API_SERVER)
|
||||
if uses_ray:
|
||||
# Must run in main thread with ray for its signal handlers to work
|
||||
engine_client = build_engine()
|
||||
else:
|
||||
engine_client = await asyncio.get_running_loop().run_in_executor(
|
||||
None, build_engine)
|
||||
|
||||
usage_context=UsageContext.OPENAI_API_SERVER,
|
||||
load_in_low_bit=load_in_low_bit)
|
||||
yield engine_client
|
||||
return
|
||||
finally:
|
||||
if engine_client and hasattr(engine_client, "shutdown"):
|
||||
engine_client.shutdown()
|
||||
|
||||
# Otherwise, use the multiprocessing AsyncLLMEngine.
|
||||
else:
@@ -163,7 +175,7 @@ async def build_async_engine_client_from_engine_args(
|
||||
# Select random path for IPC.
|
||||
ipc_path = get_open_zmq_ipc_path()
|
||||
logger.info("Multiprocessing frontend to use %s for IPC Path.",
|
||||
logger.debug("Multiprocessing frontend to use %s for IPC Path.",
|
||||
ipc_path)
|
||||
|
||||
# Start RPCServer in separate process (holds the LLMEngine).
@@ -171,37 +183,52 @@ async def build_async_engine_client_from_engine_args(
# so we need to spawn a new process
|
||||
context = multiprocessing.get_context("spawn")
|
||||
|
||||
# The Process can raise an exception during startup, which may
|
||||
# not actually result in an exitcode being reported. As a result
|
||||
# we use a shared variable to communicate the information.
|
||||
engine_alive = multiprocessing.Value('b', True, lock=False)
|
||||
engine_process = context.Process(target=run_mp_engine,
|
||||
args=(engine_args,
|
||||
UsageContext.OPENAI_API_SERVER,
|
||||
ipc_path,
|
||||
load_in_low_bit))
|
||||
ipc_path, load_in_low_bit, engine_alive))
|
||||
engine_process.start()
|
||||
logger.info("Started engine process with PID %d", engine_process.pid)
|
||||
engine_pid = engine_process.pid
|
||||
assert engine_pid is not None, "Engine process failed to start."
|
||||
logger.info("Started engine process with PID %d", engine_pid)
|
||||
|
||||
def _cleanup_ipc_path():
|
||||
socket_path = ipc_path.replace("ipc://", "")
|
||||
if os.path.exists(socket_path):
|
||||
os.remove(socket_path)
|
||||
|
||||
# Ensure we clean up the local IPC socket file on exit.
|
||||
atexit.register(_cleanup_ipc_path)
|
||||
|
||||
# Build RPCClient, which conforms to EngineClient Protocol.
|
||||
# NOTE: Actually, this is not true yet. We still need to support
|
||||
# embedding models via RPC (see TODO above)
|
||||
engine_config = engine_args.create_engine_config()
|
||||
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
|
||||
|
||||
build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
|
||||
engine_pid)
|
||||
mq_engine_client = await asyncio.get_running_loop().run_in_executor(
|
||||
None, build_client)
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
await mp_engine_client.setup()
|
||||
await mq_engine_client.setup()
|
||||
break
|
||||
except TimeoutError:
|
||||
if not engine_process.is_alive():
|
||||
if (not engine_process.is_alive()
|
||||
or not engine_alive.value):
|
||||
raise RuntimeError(
|
||||
"Engine process failed to start") from None
|
||||
"Engine process failed to start. See stack "
|
||||
"trace for the root cause.") from None
|
||||
|
||||
yield mp_engine_client # type: ignore[misc]
|
||||
yield mq_engine_client # type: ignore[misc]
|
||||
finally:
|
||||
# Ensure rpc server process was terminated
|
||||
engine_process.terminate()
|
||||
|
||||
# Close all open connections to the backend
|
||||
mp_engine_client.close()
|
||||
mq_engine_client.close()
|
||||
|
||||
# Wait for engine process to join
|
||||
engine_process.join(4)
@@ -230,7 +257,7 @@ def mount_metrics(app: FastAPI):
|
||||
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
|
||||
if prometheus_multiproc_dir_path is not None:
|
||||
logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
|
||||
logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
|
||||
prometheus_multiproc_dir_path)
|
||||
registry = CollectorRegistry()
|
||||
multiprocess.MultiProcessCollector(registry)
@@ -246,22 +273,35 @@ def mount_metrics(app: FastAPI):
app.routes.append(metrics_route)
|
||||
|
||||
|
||||
def chat(request: Request) -> OpenAIServingChat:
|
||||
def base(request: Request) -> OpenAIServing:
|
||||
# Reuse the existing instance
|
||||
return tokenization(request)
|
||||
|
||||
|
||||
def chat(request: Request) -> Optional[OpenAIServingChat]:
|
||||
return request.app.state.openai_serving_chat
|
||||
|
||||
|
||||
def completion(request: Request) -> OpenAIServingCompletion:
|
||||
def completion(request: Request) -> Optional[OpenAIServingCompletion]:
|
||||
return request.app.state.openai_serving_completion
|
||||
|
||||
|
||||
def pooling(request: Request) -> Optional[OpenAIServingPooling]:
|
||||
return request.app.state.openai_serving_pooling
|
||||
|
||||
|
||||
def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
|
||||
return request.app.state.openai_serving_embedding
|
||||
|
||||
|
||||
def score(request: Request) -> Optional[OpenAIServingScores]:
|
||||
return request.app.state.openai_serving_scores
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
|
||||
def embedding(request: Request) -> OpenAIServingEmbedding:
|
||||
return request.app.state.openai_serving_embedding
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
@@ -274,8 +314,11 @@ async def health(raw_request: Request) -> Response:
|
||||
|
||||
@router.post("/tokenize")
|
||||
@with_cancellation
|
||||
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
generator = await tokenization(raw_request).create_tokenize(request)
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
generator = await handler.create_tokenize(request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
@@ -286,8 +329,11 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
|
||||
@router.post("/detokenize")
|
||||
@with_cancellation
|
||||
async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||
generator = await tokenization(raw_request).create_detokenize(request)
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
generator = await handler.create_detokenize(request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
@@ -299,7 +345,9 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||
@router.get("/v1/models")
|
||||
async def show_available_models(raw_request: Request):
|
||||
models = await completion(raw_request).show_available_models()
|
||||
handler = base(raw_request)
|
||||
|
||||
models = await handler.show_available_models()
|
||||
return JSONResponse(content=models.model_dump())
@@ -310,11 +358,15 @@ async def show_version():
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
@with_cancellation
|
||||
async def create_chat_completion(request: ChatCompletionRequest,
|
||||
raw_request: Request):
|
||||
handler = chat(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Chat Completions API")
|
||||
|
||||
generator = await chat(raw_request).create_chat_completion(
|
||||
request, raw_request)
|
||||
generator = await handler.create_chat_completion(request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
@@ -327,9 +379,14 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
|
||||
@router.post("/v1/completions")
|
||||
@with_cancellation
|
||||
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
generator = await completion(raw_request).create_completion(
|
||||
request, raw_request)
|
||||
handler = completion(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Completions API")
|
||||
|
||||
generator = await handler.create_completion(request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
@@ -340,9 +397,40 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
@with_cancellation
|
||||
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||
generator = await embedding(raw_request).create_embedding(
|
||||
request, raw_request)
|
||||
handler = embedding(raw_request)
|
||||
if handler is None:
|
||||
fallback_handler = pooling(raw_request)
|
||||
if fallback_handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Embeddings API")
|
||||
|
||||
logger.warning(
|
||||
"Embeddings API will become exclusive to embedding models "
|
||||
"in a future release. To return the hidden states directly, "
|
||||
"use the Pooling API (`/pooling`) instead.")
|
||||
|
||||
res = await fallback_handler.create_pooling(request, raw_request)
|
||||
if isinstance(res, PoolingResponse):
|
||||
generator = EmbeddingResponse(
|
||||
id=res.id,
|
||||
object=res.object,
|
||||
created=res.created,
|
||||
model=res.model,
|
||||
data=[
|
||||
EmbeddingResponseData(
|
||||
index=d.index,
|
||||
embedding=d.data, # type: ignore
|
||||
) for d in res.data
|
||||
],
|
||||
usage=res.usage,
|
||||
)
|
||||
else:
|
||||
generator = res
|
||||
else:
|
||||
generator = await handler.create_embedding(request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
@@ -352,6 +440,52 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/pooling")
|
||||
@with_cancellation
|
||||
async def create_pooling(request: PoolingRequest, raw_request: Request):
|
||||
handler = pooling(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Pooling API")
|
||||
|
||||
generator = await handler.create_pooling(request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
elif isinstance(generator, PoolingResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/score")
|
||||
@with_cancellation
|
||||
async def create_score(request: ScoreRequest, raw_request: Request):
|
||||
handler = score(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Score API")
|
||||
|
||||
generator = await handler.create_score(request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
elif isinstance(generator, ScoreResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/v1/score")
|
||||
@with_cancellation
|
||||
async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
||||
logger.warning(
|
||||
"To indicate that Score API is not part of standard OpenAI API, we "
|
||||
"have moved it to `/score`. Please update your client accordingly.")
|
||||
|
||||
return await create_score(request, raw_request)
|
||||
|
||||
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
logger.warning(
|
||||
"Torch Profiler is enabled in the API server. This should ONLY be "
@@ -380,12 +514,10 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
@router.post("/v1/load_lora_adapter")
|
||||
async def load_lora_adapter(request: LoadLoraAdapterRequest,
|
||||
raw_request: Request):
|
||||
response = await chat(raw_request).load_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
|
||||
response = await completion(raw_request).load_lora_adapter(request)
|
||||
for route in [chat, completion, embedding]:
|
||||
handler = route(raw_request)
|
||||
if handler is not None:
|
||||
response = await handler.load_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
@@ -395,12 +527,10 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
@router.post("/v1/unload_lora_adapter")
|
||||
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
|
||||
raw_request: Request):
|
||||
response = await chat(raw_request).unload_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
|
||||
response = await completion(raw_request).unload_lora_adapter(request)
|
||||
for route in [chat, completion, embedding]:
|
||||
handler = route(raw_request)
|
||||
if handler is not None:
|
||||
response = await handler.unload_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
@@ -431,8 +561,9 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(_, exc):
|
||||
chat = app.state.openai_serving_chat
|
||||
err = chat.create_error_response(message=str(exc))
|
||||
err = ErrorResponse(message=str(exc),
|
||||
type="BadRequestError",
|
||||
code=HTTPStatus.BAD_REQUEST)
|
||||
return JSONResponse(err.model_dump(),
|
||||
status_code=HTTPStatus.BAD_REQUEST)
@@ -440,16 +571,31 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
@app.middleware("http")
|
||||
async def authentication(request: Request, call_next):
|
||||
root_path = "" if args.root_path is None else args.root_path
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
if not request.url.path.startswith(f"{root_path}/v1"):
|
||||
url_path = request.url.path
|
||||
if app.root_path and url_path.startswith(app.root_path):
|
||||
url_path = url_path[len(app.root_path):]
|
||||
if not url_path.startswith("/v1"):
|
||||
return await call_next(request)
|
||||
if request.headers.get("Authorization") != "Bearer " + token:
|
||||
return JSONResponse(content={"error": "Unauthorized"},
|
||||
status_code=401)
|
||||
return await call_next(request)
|
||||
|
||||
if args.enable_request_id_headers:
|
||||
logger.warning(
|
||||
"CAUTION: Enabling X-Request-Id headers in the API Server. "
|
||||
"This can harm performance at high QPS.")
|
||||
|
||||
@app.middleware("http")
|
||||
async def add_request_id(request: Request, call_next):
|
||||
request_id = request.headers.get(
|
||||
"X-Request-Id") or uuid.uuid4().hex
|
||||
response = await call_next(request)
|
||||
response.headers["X-Request-Id"] = request_id
|
||||
return response
|
||||
|
||||
for middleware in args.middleware:
|
||||
module_path, object_name = middleware.rsplit(".", 1)
|
||||
imported = getattr(importlib.import_module(module_path), object_name)
@@ -488,49 +634,179 @@ def init_app_state(
state.engine_client = engine_client
|
||||
state.log_stats = not args.disable_log_stats
|
||||
|
||||
resolved_chat_template = load_chat_template(args.chat_template)
|
||||
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
|
||||
|
||||
state.openai_serving_models = OpenAIServingModels(
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=args.lora_modules,
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
)
|
||||
# TODO: The chat template is now broken for lora adapters :(
|
||||
state.openai_serving_chat = OpenAIServingChat(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
state.openai_serving_models,
|
||||
args.response_role,
|
||||
lora_modules=args.lora_modules,
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
request_logger=request_logger,
|
||||
chat_template=args.chat_template,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser)
|
||||
tool_parser=args.tool_call_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
) if model_config.runner_type == "generate" else None
|
||||
state.openai_serving_completion = OpenAIServingCompletion(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
lora_modules=args.lora_modules,
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
)
|
||||
) if model_config.runner_type == "generate" else None
|
||||
state.openai_serving_pooling = OpenAIServingPooling(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
) if model_config.runner_type == "pooling" else None
|
||||
state.openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
)
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
) if model_config.task == "embed" else None
|
||||
state.openai_serving_scores = OpenAIServingScores(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger
|
||||
) if model_config.task == "score" else None
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
lora_modules=args.lora_modules,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=args.chat_template,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
)
|
||||
state.task = model_config.task
|
||||
# if args.served_model_name is not None:
|
||||
# served_model_names = args.served_model_name
|
||||
# else:
|
||||
# served_model_names = [args.model]
|
||||
|
||||
# if args.disable_log_requests:
|
||||
# request_logger = None
|
||||
# else:
|
||||
# request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
|
||||
# base_model_paths = [
|
||||
# BaseModelPath(name=name, model_path=args.model)
|
||||
# for name in served_model_names
|
||||
# ]
|
||||
|
||||
# state.engine_client = engine_client
|
||||
# state.log_stats = not args.disable_log_stats
|
||||
|
||||
# resolved_chat_template = load_chat_template(args.chat_template)
|
||||
# logger.info("Using supplied chat template:\n%s", resolved_chat_template)
|
||||
|
||||
# state.openai_serving_chat = OpenAIServingChat(
|
||||
# engine_client,
|
||||
# model_config,
|
||||
# base_model_paths,
|
||||
# args.response_role,
|
||||
# lora_modules=args.lora_modules,
|
||||
# prompt_adapters=args.prompt_adapters,
|
||||
# request_logger=request_logger,
|
||||
# chat_template=resolved_chat_template,
|
||||
# chat_template_content_format=args.chat_template_content_format,
|
||||
# return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
# enable_auto_tools=args.enable_auto_tool_choice,
|
||||
# tool_parser=args.tool_call_parser,
|
||||
# enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
# ) if model_config.runner_type == "generate" else None
|
||||
# state.openai_serving_completion = OpenAIServingCompletion(
|
||||
# engine_client,
|
||||
# model_config,
|
||||
# base_model_paths,
|
||||
# lora_modules=args.lora_modules,
|
||||
# prompt_adapters=args.prompt_adapters,
|
||||
# request_logger=request_logger,
|
||||
# return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
# ) if model_config.runner_type == "generate" else None
|
||||
# state.openai_serving_pooling = OpenAIServingPooling(
|
||||
# engine_client,
|
||||
# model_config,
|
||||
# base_model_paths,
|
||||
# request_logger=request_logger,
|
||||
# chat_template=resolved_chat_template,
|
||||
# chat_template_content_format=args.chat_template_content_format,
|
||||
# ) if model_config.runner_type == "pooling" else None
|
||||
# state.openai_serving_embedding = OpenAIServingEmbedding(
|
||||
# engine_client,
|
||||
# model_config,
|
||||
# base_model_paths,
|
||||
# request_logger=request_logger,
|
||||
# chat_template=resolved_chat_template,
|
||||
# chat_template_content_format=args.chat_template_content_format,
|
||||
# ) if model_config.task == "embed" else None
|
||||
# state.openai_serving_scores = OpenAIServingScores(
|
||||
# engine_client,
|
||||
# model_config,
|
||||
# base_model_paths,
|
||||
# request_logger=request_logger
|
||||
# ) if model_config.task == "score" else None
|
||||
# state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
# engine_client,
|
||||
# model_config,
|
||||
# base_model_paths,
|
||||
# lora_modules=args.lora_modules,
|
||||
# request_logger=request_logger,
|
||||
# chat_template=resolved_chat_template,
|
||||
# chat_template_content_format=args.chat_template_content_format,
|
||||
# )
|
||||
|
||||
|
||||
def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
|
||||
family = socket.AF_INET
|
||||
if is_valid_ipv6_address(addr[0]):
|
||||
family = socket.AF_INET6
|
||||
|
||||
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(addr)
|
||||
|
||||
return sock
|
||||
|
||||
|
||||
async def run_server(args, **uvicorn_kwargs) -> None:
|
||||
logger.info("vLLM API server version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
temp_socket.bind(("", args.port))
|
||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
|
||||
valide_tool_parses = ToolParserManager.tool_parsers.keys()
|
||||
if args.enable_auto_tool_choice \
|
||||
and args.tool_call_parser not in valide_tool_parses:
|
||||
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
|
||||
f"(chose from {{ {','.join(valide_tool_parses)} }})")
|
||||
|
||||
# workaround to make sure that we bind the port before the engine is set up.
|
||||
# This avoids race conditions with ray.
|
||||
# see https://github.com/vllm-project/vllm/issues/8204
|
||||
sock_addr = (args.host or "", args.port)
|
||||
sock = create_server_socket(sock_addr)
|
||||
|
||||
# workaround to avoid footguns where uvicorn drops requests with too
|
||||
# many concurrent requests active
|
||||
set_ulimit()
|
||||
|
||||
def signal_handler(*_) -> None:
|
||||
# Interrupt server on sigterm while initializing
@@ -544,8 +820,6 @@ async def run_server(args, **uvicorn_kwargs) -> None:
model_config = await engine_client.get_model_config()
|
||||
init_app_state(engine_client, model_config, app.state, args)
|
||||
|
||||
temp_socket.close()
|
||||
|
||||
shutdown_task = await serve_http(
|
||||
app,
|
||||
host=args.host,
@@ -562,13 +836,23 @@ async def run_server(args, **uvicorn_kwargs) -> None:
# NB: Await server shutdown only after the backend context is exited
|
||||
await shutdown_task
|
||||
|
||||
sock.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# NOTE(simon):
|
||||
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
|
||||
logger.warning("Warning: Please use `ipex_llm.vllm.xpu.entrypoints.openai.api_server` "
|
||||
"instead of `vllm.entrypoints.openai.api_server` to start the API server")
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||
parser = make_arg_parser(parser)
|
||||
parser.add_argument(
|
||||
"--load-in-low-bit",
|
||||
type=str,
|
||||
default="sym_int4",
|
||||
help="Low-bit quantization for IPEX-LLM models")
|
||||
args = parser.parse_args()
|
||||
validate_parsed_serve_args(args)
|
||||
|
||||
uvloop.run(run_server(args))
@@ -7,11 +7,14 @@ purposes.
import argparse
|
||||
import json
|
||||
import ssl
|
||||
from typing import List, Optional, Sequence, Union
|
||||
from typing import List, Optional, Sequence, Union, get_args
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
||||
validate_chat_template)
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
PromptAdapterPath)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.utils import FlexibleArgumentParser
@@ -130,10 +133,23 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="The file path to the chat template, "
|
||||
"or the template in single-line form "
|
||||
"for the specified model")
|
||||
parser.add_argument(
|
||||
'--chat-template-content-format',
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=get_args(ChatTemplateContentFormatOption),
|
||||
help='The format to render message content within a chat template.'
|
||||
'\n\n'
|
||||
'* "string" will render the content as a string. '
|
||||
'Example: "Hello World"\n'
|
||||
'* "openai" will render the content as a list of dictionaries, '
|
||||
'similar to OpenAI schema. '
|
||||
'Example: [{"type": "text", "text": "Hello world!"}]')
|
||||
parser.add_argument("--response-role",
|
||||
type=nullable_str,
|
||||
default="assistant",
|
||||
help="The role name to return if `request.add_generation_prompt=true`.")
|
||||
help="The role name to return if "
|
||||
"`request.add_generation_prompt=true`.")
|
||||
parser.add_argument("--ssl-keyfile",
|
||||
type=nullable_str,
|
||||
default=None,
@@ -180,28 +196,36 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action="store_true",
|
||||
help="If specified, will run the OpenAI frontend server in the same "
|
||||
"process as the model serving engine.")
|
||||
|
||||
parser.add_argument(
|
||||
"--enable-request-id-headers",
|
||||
action="store_true",
|
||||
help="If specified, API server will add X-Request-Id header to "
|
||||
"responses. Caution: this hurts performance at high QPS.")
|
||||
parser.add_argument(
|
||||
"--enable-auto-tool-choice",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable auto tool choice for supported models. Use --tool-call-parser"
|
||||
"to specify which parser to use")
|
||||
" to specify which parser to use")
|
||||
|
||||
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
|
||||
parser.add_argument(
|
||||
"--tool-call-parser",
|
||||
type=str,
|
||||
choices=["mistral", "hermes"],
|
||||
metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
|
||||
"--tool-parser-plugin",
|
||||
default=None,
|
||||
help="Select the tool call parser depending on the model that you're using."
|
||||
" This is used to parse the model-generated tool call into OpenAI API "
|
||||
"format. Required for --enable-auto-tool-choice.")
|
||||
|
||||
parser.add_argument(
|
||||
"--load-in-low-bit",
|
||||
"--tool-parser-plugin",
|
||||
type=str,
|
||||
default="sym_int4",
|
||||
help="Low-bit quantization for IPEX-LLM models")
|
||||
default="",
|
||||
help="Special the tool parser plugin write to parse the model-generated tool"
|
||||
" into OpenAI API format, the name register in this plugin can be used "
|
||||
"in --tool-call-parser.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
@@ -218,10 +242,35 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=False,
|
||||
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-prompt-tokens-details",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="If set to True, enable prompt_tokens_details in usage.")
|
||||
|
||||
parser.add_argument(
|
||||
"--load-in-low-bit",
|
||||
type=str,
|
||||
default="sym_int4",
|
||||
help="Low-bit quantization for IPEX-LLM models")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def validate_parsed_serve_args(args: argparse.Namespace):
|
||||
"""Quick checks for model serve args that raise prior to loading.""" # noqa
|
||||
if hasattr(args, "subparser") and args.subparser != "serve":
|
||||
return
|
||||
|
||||
# Ensure that the chat template is valid; raises if it likely isn't
|
||||
validate_chat_template(args.chat_template)
|
||||
|
||||
# Enable auto tool needs a tool call parser to be valid
|
||||
if args.enable_auto_tool_choice and not args.tool_call_parser:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires " # noqa
|
||||
"--tool-call-parser")
|
||||
|
||||
|
||||
def create_parser_for_docs() -> FlexibleArgumentParser:
|
||||
parser_for_docs = FlexibleArgumentParser(
|
||||
prog="-m vllm.entrypoints.openai.api_server")
|
||||
|
|
23 python/llm/src/ipex_llm/vllm/xpu/ipex_llm_v1_wrapper.py Normal file
@@ -0,0 +1,23 @@
from vllm.logger import init_logger
from vllm.v1.executor.ray_utils import RayWorkerWrapper


logger = init_logger(__name__)


class IPEXLLMV1Wrapper(RayWorkerWrapper):
    def __init__(self, load_in_low_bit="sym_int4", *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
        _ipex_llm_convert(load_in_low_bit=load_in_low_bit)
        self.compiled_dag_cuda_device_set = False


def get_ipex_llm_v1_wrapper(load_in_low_bit):
    # The reason why we not using functools.partial is that
    # ray seems not work well with it.
    class WrapperWithLoadBit(IPEXLLMV1Wrapper):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(load_in_low_bit=load_in_low_bit, *args, **kwargs)

    return WrapperWithLoadBit
@@ -65,9 +65,14 @@ def _model_sample_convert():
def _ipex_llm_convert(load_in_low_bit):
|
||||
from vllm.worker.xpu_model_runner import XPUModelRunner
|
||||
from ipex_llm.vllm.xpu.ipex_llm_wrapper import get_ipex_llm_wrapper
|
||||
import vllm.executor.ray_utils as ray_utils
|
||||
from ipex_llm.vllm.xpu.ipex_llm_v1_wrapper import get_ipex_llm_v1_wrapper
|
||||
import vllm.executor.ray_utils as ray_utils_v0
|
||||
import vllm.v1.executor.ray_utils as ray_utils_v1
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
setattr(XPUModelRunner, "load_model", get_load_function(load_in_low_bit))
|
||||
setattr(ray_utils, "RayWorkerWrapper", get_ipex_llm_wrapper(load_in_low_bit))
|
||||
setattr(GPUModelRunner, "load_model", get_load_function(load_in_low_bit))
|
||||
setattr(ray_utils_v0, "RayWorkerWrapper", get_ipex_llm_wrapper(load_in_low_bit))
|
||||
setattr(ray_utils_v1, "RayWorkerWrapper", get_ipex_llm_v1_wrapper(load_in_low_bit))
|
||||
|
||||
|
||||
def get_load_function(low_bit):
@@ -77,19 +82,16 @@ def get_load_function(low_bit):
# from vllm.utils import measure_device_memory
|
||||
from vllm.utils import DeviceMemoryProfiler
|
||||
with DeviceMemoryProfiler() as m:
|
||||
from dataclasses import replace
|
||||
new_device_config = DeviceConfig("cpu")
|
||||
new_vllm_config = replace(self.vllm_config, device_config=new_device_config)
|
||||
self.model = get_model(
|
||||
model_config=self.model_config,
|
||||
device_config=DeviceConfig("cpu"),
|
||||
load_config=self.load_config,
|
||||
lora_config=self.lora_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
cache_config=self.cache_config,
|
||||
vllm_config=new_vllm_config
|
||||
)
|
||||
if "qwen" in self.model_config.model.lower() or \
|
||||
"baichuan" in self.model_config.model.lower() or \
|
||||
"codegeex4-all" in self.model_config.model.lower() or \
|
||||
"chatglm" in self.model_config.model.lower():
|
||||
if "qwen" in self.vllm_config.model_config.model.lower() or \
|
||||
"baichuan" in self.vllm_config.model_config.model.lower() or \
|
||||
"codegeex4-all" in self.vllm_config.model_config.model.lower() or \
|
||||
"chatglm" in self.vllm_config.model_config.model.lower():
|
||||
self.model.apply(padding_mlp)
|
||||
from ipex_llm import optimize_model
|
||||
import os
@@ -99,18 +101,22 @@ def get_load_function(low_bit):
modules = ["35.mlp", "36.mlp", "37.mlp", "38.mlp", "39.mlp"]
|
||||
else:
|
||||
modules = None
|
||||
if "minicpm" in self.model_config.model.lower():
|
||||
if "minicpm" in self.vllm_config.model_config.model.lower():
|
||||
modules = ["vpm", "resampler"]
|
||||
# only for minicpm_2_6
|
||||
if "minicpm-v" in self.model_config.model.lower():
|
||||
if "minicpm-v" in self.vllm_config.model_config.model.lower():
|
||||
from ipex_llm.transformers.models.minicpmv import merge_qkv
|
||||
self.model.vpm.apply(merge_qkv)
|
||||
if "internvl2" in self.model_config.model.lower():
|
||||
if "internvl2" in self.vllm_config.model_config.model.lower():
|
||||
modules = ["vision_model", "mlp1"]
|
||||
optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype,
|
||||
if "deepseek-v2" in self.vllm_config.model_config.model.lower():
|
||||
modules = ["down_proj"]
|
||||
optimize_model(self.model,
|
||||
low_bit=low_bit,
|
||||
torch_dtype=self.vllm_config.model_config.dtype,
|
||||
modules_to_not_convert=modules)
|
||||
self.model = self.model.to(device=self.device_config.device,
|
||||
dtype=self.model_config.dtype)
|
||||
self.model = self.model.to(device=self.vllm_config.device_config.device,
|
||||
dtype=self.vllm_config.model_config.dtype)
|
||||
|
||||
self.model_memory_usage = m.consumed_memory
|
||||
logger = init_logger(__name__)