Switching from vLLM v0.3.3 to vLLM 0.5.4 (#12042)
* Enable single card sync engine
* enable ipex-llm optimizations for vllm
* enable optimizations for lm_head
* Fix chatglm multi-reference problem
* Remove duplicate layer
* LLM: Update vLLM to v0.5.4 (#11746)
* Enable single card sync engine
* enable ipex-llm optimizations for vllm
* enable optimizations for lm_head
* Fix chatglm multi-reference problem
* update 0.5.4 api_server
* add dockerfile
* fix
* fix
* refine
* fix
---------
Co-authored-by: gc-fu <guancheng.fu@intel.com>
* Add vllm-0.5.4 Dockerfile (#11838)
* Update BIGDL_LLM_SDP_IGNORE_MASK in start-vllm-service.sh (#11957)
* Fix vLLM not convert issues (#11817) (#11918)
* Fix not convert issues
* refine
Co-authored-by: Guancheng Fu <110874468+gc-fu@users.noreply.github.com>
* Fix glm4-9b-chat nan error on vllm 0.5.4 (#11969)
* init
* update mlp forward
* fix minicpm error in vllm 0.5.4
* fix dependabot alerts (#12008)
* Update 0.5.4 dockerfile (#12021)
* Add vllm awq loading logic (#11987)
* [ADD] Add vllm awq loading logic
* [FIX] fix the module.linear_method path
* [FIX] fix quant_config path error
* Enable Qwen padding mlp to 256 to support batch_forward (#12030)
* Enable padding mlp
* padding to 256
* update style
* Install 27191 runtime in 0.5.4 docker image (#12040)
* fix rebase error
* fix rebase error
* vLLM: format for 0.5.4 rebase (#12043)
* format
* Update model_convert.py
* Fix serving docker related modifications (#12046)
* Fix undesired modifications (#12048)
* fix
* Refine offline_inference arguments
---------
Co-authored-by: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com>
Co-authored-by: Jun Wang <thoughts.times@gmail.com>
Co-authored-by: Wang, Jian4 <61138589+hzjane@users.noreply.github.com>
Co-authored-by: liu-shaojun <johnssalyn@outlook.com>
Co-authored-by: Shaojun Liu <61072813+liu-shaojun@users.noreply.github.com>
Parent: 73a4360f3f
Commit: 69c8d36f16
14 changed files with 903 additions and 1009 deletions
Serving Dockerfile:

@@ -1,60 +1,90 @@
-FROM intelanalytics/ipex-llm-serving-xpu:latest as build
+FROM intel/oneapi-basekit:2024.1.1-devel-ubuntu22.04

 ARG http_proxy
 ARG https_proxy

-ADD ./oneccl-binding.patch /tmp/oneccl-binding.patch
+ENV TZ=Asia/Shanghai
+ENV PYTHONUNBUFFERED=1

-RUN cd /tmp/ && \
-    pip install --upgrade setuptools wheel twine && \
-    pip install "setuptools<70.0.0" && \
-    git clone https://github.com/intel/torch-ccl -b v2.1.100+xpu && \
-    cd torch-ccl && \
-    patch -p1 < /tmp/oneccl-binding.patch && \
-    git submodule sync && \
-    git submodule update --init --recursive && \
-    COMPUTE_BACKEND=dpcpp python setup.py sdist bdist_wheel && \
-    mv /tmp/torch-ccl/dist/oneccl_bind_pt-2.1.100+xpu-cp311-cp311-linux_x86_64.whl /tmp/
-
-FROM intelanalytics/ipex-llm-xpu:2.2.0-SNAPSHOT
-
-ARG http_proxy
-ARG https_proxy

 # Disable pip's cache behavior
 ARG PIP_NO_CACHE_DIR=false
-COPY --from=build /tmp/oneccl_bind_pt-2.1.100+xpu-cp311-cp311-linux_x86_64.whl /tmp/
 ADD ./gradio_web_server.patch /tmp/gradio_web_server.patch
+ADD ./oneccl-binding.patch /tmp/oneccl-binding.patch

-# Install Serving Dependencies
-# Install ipex-llm[serving] only will update ipex_llm source code without updating
-# bigdl-core-xe, which will lead to problems
-RUN apt-get update && \
-    apt-get install -y --no-install-recommends libfabric-dev wrk libaio-dev && \
-    apt-get install -y intel-opencl-icd intel-level-zero-gpu=1.3.26241.33-647~22.04 level-zero level-zero-dev --allow-downgrades && \
-    pip install --pre --upgrade ipex-llm[xpu,serving] && \
-    pip install transformers==4.37.0 gradio==4.19.2 && \
-    # Install vLLM-v2 dependencies
-    git clone -b sycl_xpu https://github.com/analytics-zoo/vllm.git /llm/vllm && \
-    pip install -r /llm/vllm/requirements-xpu.txt && \
-    pip install --no-deps xformers && \
-    VLLM_BUILD_XPU_OPS=1 pip install --no-build-isolation -v -e /llm/vllm && \
-    pip install outlines==0.0.34 --no-deps && \
-    pip install interegular cloudpickle diskcache joblib lark nest-asyncio numba scipy && \
-    # For Qwen series models support
+# Install Serving Dependencies
+RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
+    echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
+    chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \
+    rm /etc/apt/sources.list.d/intel-graphics.list && \
+    wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \
+    echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \
+    chmod 644 /usr/share/keyrings/intel-graphics.gpg && \
+    apt-get update && \
+    apt-get install -y --no-install-recommends curl wget git libunwind8-dev vim less && \
+    ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && \
+    env DEBIAN_FRONTEND=noninteractive apt-get update && \
+    # add-apt-repository requires gnupg, gpg-agent, software-properties-common
+    apt-get install -y --no-install-recommends gnupg gpg-agent software-properties-common && \
+    # Add Python 3.11 PPA repository
+    add-apt-repository ppa:deadsnakes/ppa -y && \
+    apt-get install -y --no-install-recommends python3.11 git curl wget && \
+    rm /usr/bin/python3 && \
+    ln -s /usr/bin/python3.11 /usr/bin/python3 && \
+    ln -s /usr/bin/python3 /usr/bin/python && \
+    apt-get install -y --no-install-recommends python3-pip python3.11-dev python3-wheel python3.11-distutils && \
+    wget https://bootstrap.pypa.io/get-pip.py -O get-pip.py && \
+    # Install FastChat from source requires PEP 660 support
+    python3 get-pip.py && \
+    rm get-pip.py && \
+    pip install --upgrade requests argparse urllib3 && \
+    pip install --pre --upgrade ipex-llm[xpu,serving] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ && \
+    pip install transformers==4.36.2 && \
     pip install transformers_stream_generator einops tiktoken && \
-    # For pipeline serving support
-    pip install mpi4py fastapi uvicorn openai && \
-    # for gradio web UI
-    pip install gradio && \
-    # Install internal oneccl && \
+    pip install --upgrade colorama && \
+    # Download all-in-one benchmark and examples
+    git clone https://github.com/intel-analytics/ipex-llm && \
+    cp -r ./ipex-llm/python/llm/dev/benchmark/ ./benchmark && \
+    cp -r ./ipex-llm/python/llm/example/GPU/HuggingFace/LLM ./examples && \
+    # Install vllm dependencies
+    pip install --upgrade fastapi && \
+    pip install --upgrade "uvicorn[standard]" && \
+    # Download vLLM-Serving
+    cp -r ./ipex-llm/python/llm/example/GPU/vLLM-Serving/ ./vLLM-Serving && \
+    rm -rf ./ipex-llm && \
+    # Install torch-ccl
     cd /tmp/ && \
+    pip install torch==2.1.0.post2 torchvision==0.16.0.post2 torchaudio==2.1.0.post2 intel-extension-for-pytorch==2.1.30.post0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ && \
+    # Internal oneccl
     wget https://sourceforge.net/projects/oneccl-wks/files/oneccl_wks_installer_2024.0.0.2.sh && \
     bash oneccl_wks_installer_2024.0.0.2.sh && \
-    pip uninstall -y oneccl_bind_pt && \
-    pip install /tmp/oneccl_bind_pt-2.1.100+xpu-cp311-cp311-linux_x86_64.whl && \
-    rm /tmp/oneccl_bind_pt-2.1.100+xpu-cp311-cp311-linux_x86_64.whl && \
+    git clone https://github.com/intel/torch-ccl -b v2.1.300+xpu && \
+    cd torch-ccl && \
+    patch -p1 < /tmp/oneccl-binding.patch && \
+    USE_SYSTEM_ONECCL=ON COMPUTE_BACKEND=dpcpp python setup.py install && \
+    apt-get update && \
+    apt-get install -y --no-install-recommends libfabric-dev wrk libaio-dev numactl && \
+    # apt-get install -y intel-opencl-icd intel-level-zero-gpu=1.3.26241.33-647~22.04 level-zero level-zero-dev --allow-downgrades && \
+    mkdir -p /tmp/neo && \
+    cd /tmp/neo && \
+    wget https://github.com/intel/intel-graphics-compiler/releases/download/igc-1.0.15136.4/intel-igc-core_1.0.15136.4_amd64.deb && \
+    wget https://github.com/intel/intel-graphics-compiler/releases/download/igc-1.0.15136.4/intel-igc-opencl_1.0.15136.4_amd64.deb && \
+    wget https://github.com/intel/compute-runtime/releases/download/23.35.27191.9/intel-level-zero-gpu-dbgsym_1.3.27191.9_amd64.ddeb && \
+    wget https://github.com/intel/compute-runtime/releases/download/23.35.27191.9/intel-level-zero-gpu_1.3.27191.9_amd64.deb && \
+    wget https://github.com/intel/compute-runtime/releases/download/23.35.27191.9/intel-opencl-icd-dbgsym_23.35.27191.9_amd64.ddeb && \
+    wget https://github.com/intel/compute-runtime/releases/download/23.35.27191.9/intel-opencl-icd_23.35.27191.9_amd64.deb && \
+    wget https://github.com/intel/compute-runtime/releases/download/23.35.27191.9/libigdgmm12_22.3.11.ci17747749_amd64.deb && \
+    dpkg -i *.deb && \
+    rm -rf /tmp/neo && \
+    mkdir -p /llm && \
+    cd /llm && \
+    git clone -b 0.5.4 https://github.com/analytics-zoo/vllm.git /llm/vllm && \
+    cd /llm/vllm && \
+    pip install -r /llm/vllm/requirements-xpu.txt && \
+    VLLM_TARGET_DEVICE=xpu python setup.py install && \
+    pip install mpi4py fastapi uvicorn openai && \
+    pip install gradio && \
+    # patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch && \
+    pip install ray && \
     patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch

 COPY ./vllm_online_benchmark.py /llm/
@@ -66,5 +96,7 @@ COPY ./start-fastchat-service.sh /llm/
 COPY ./start-pp_serving-service.sh /llm/
 COPY ./start-lightweight_serving-service.sh /llm/

+ENV LD_LIBRARY_PATH /usr/local/lib/python3.11/dist-packages/intel_extension_for_pytorch/lib/:/opt/intel/oneapi/tbb/2021.12/env/../lib/intel64/gcc4.8:/opt/intel/oneapi/mpi/2021.12/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/2021.12/lib:/opt/intel/oneapi/mkl/2024.1/lib:/opt/intel/oneapi/ippcp/2021.11/lib/:/opt/intel/oneapi/ipp/2021.11/lib:/opt/intel/oneapi/dpl/2022.5/lib:/opt/intel/oneapi/dnnl/2024.1/lib:/opt/intel/oneapi/debugger/2024.1/opt/debugger/lib:/opt/intel/oneapi/dal/2024.2/lib:/opt/intel/oneapi/compiler/2024.1/opt/oclfpga/host/linux64/lib:/opt/intel/oneapi/compiler/2024.1/opt/compiler/lib:/opt/intel/oneapi/compiler/2024.1/lib:/opt/intel/oneapi/ccl/2021.12/lib/
+ENV BIGDL_LLM_SDP_IGNORE_MASK 0

 WORKDIR /llm/
Throughput benchmark (run_vllm):

@@ -103,41 +103,38 @@ def run_vllm(
     warm_prompt = "hi " * (1024 - 1)
     warm_requests = [(warm_prompt, 1024, 1024)
                      for _ in range(8)]
-    for prompt, _, output_len in warm_requests:
-        sampling_params = SamplingParams(
-            n=n,
-            temperature=0.0 if use_beam_search else 1.0,
-            top_p=1.0,
-            use_beam_search=use_beam_search,
-            ignore_eos=True,
-            max_tokens=output_len,
-        )
-        llm._add_request(
-            prompt=prompt,
-            prompt_token_ids=None,
-            sampling_params=sampling_params,
-        )
-    llm._run_engine(use_tqdm=True)
+    prompts: List[str] = []
+    sampling_params: List[SamplingParams] = []
+    for prompt, _, output_len in warm_requests:
+        prompts.append(prompt)
+        sampling_params.append(
+            SamplingParams(
+                n=n,
+                temperature=0.0 if use_beam_search else 1.0,
+                top_p=1.0,
+                use_beam_search=use_beam_search,
+                ignore_eos=True,
+                max_tokens=output_len,
+            ))
+    llm.generate(prompts, sampling_params, use_tqdm=True)

+    prompts: List[str] = []
+    sampling_params: List[SamplingParams] = []
     for prompt, _, output_len in requests:
-        sampling_params = SamplingParams(
-            n=n,
-            temperature=0.0 if use_beam_search else 1.0,
-            top_p=1.0,
-            use_beam_search=use_beam_search,
-            ignore_eos=True,
-            max_tokens=output_len,
-        )
-        # FIXME(woosuk): Do not use internal method.
-        llm._add_request(
-            prompt=prompt,
-            prompt_token_ids=None,
-            sampling_params=sampling_params,
-        )
+        prompts.append(prompt)
+        sampling_params.append(
+            SamplingParams(
+                n=n,
+                temperature=0.0 if use_beam_search else 1.0,
+                top_p=1.0,
+                use_beam_search=use_beam_search,
+                ignore_eos=True,
+                max_tokens=output_len,
+            ))

     start = time.perf_counter()
-    # FIXME(woosuk): Do not use internal method.
-    llm._run_engine(use_tqdm=True)
+    llm.generate(prompts, sampling_params, use_tqdm=True)
     end = time.perf_counter()
     return end - start
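For readers unfamiliar with the 0.5.x API, the benchmark change above is mechanical: instead of feeding requests one at a time through the engine's private _add_request/_run_engine helpers, all prompts and their SamplingParams are collected into lists and handed to LLM.generate, which batches and schedules them internally. A minimal sketch against plain vLLM (the model path and prompts are placeholders, not taken from this commit):

from vllm import LLM, SamplingParams

llm = LLM(model="YOUR_MODEL_PATH")  # placeholder path

prompts = ["Hello, my name is", "The capital of France is"]
# One SamplingParams per prompt; a single shared object also works.
params = [SamplingParams(temperature=0.8, max_tokens=64) for _ in prompts]

# generate() returns one RequestOutput per prompt, in submission order.
for output in llm.generate(prompts, params, use_tqdm=True):
    print(output.outputs[0].text)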
start-vllm-service.sh:

@@ -3,6 +3,7 @@ model="YOUR_MODEL_PATH"
 served_model_name="YOUR_MODEL_NAME"

 source /opt/intel/1ccl-wks/setvars.sh
+export BIGDL_LLM_SDP_IGNORE_MASK=0

 python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
   --served-model-name $served_model_name \
@@ -17,4 +18,4 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
   --max-model-len 4096 \
   --max-num-batched-tokens 10240 \
   --max-num-seqs 12 \
   --tensor-parallel-size 1
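Once start-vllm-service.sh brings the server up, it speaks the standard OpenAI protocol on the configured port (8000 by default in these scripts). A hedged sketch of a client call, assuming the default local host/port and the served model name set above:

from openai import OpenAI

# Assumes the service started by start-vllm-service.sh is listening locally.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

response = client.completions.create(
    model="YOUR_MODEL_NAME",      # must match --served-model-name
    prompt="San Francisco is a",
    max_tokens=64,
)
print(response.choices[0].text)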
Offline inference example:

@@ -49,8 +49,10 @@ llm = LLM(model="YOUR_MODEL",
           device="xpu",
           dtype="float16",
           enforce_eager=True,
-          load_in_low_bit="sym_int4",
-          tensor_parallel_size=1)
+          load_in_low_bit="fp8",
+          tensor_parallel_size=1,
+          max_model_len=2000,
+          max_num_batched_tokens=2000)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
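Put together, the refreshed offline example builds the ipex-llm LLM wrapper with the new memory-related arguments shown above. A hedged end-to-end sketch (prompts and the model path are placeholders; the import path and argument names follow this commit, but treat it as illustrative rather than the exact example file):

from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

prompts = ["Hello, my name is", "The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="YOUR_MODEL",          # placeholder
          device="xpu",
          dtype="float16",
          enforce_eager=True,
          load_in_low_bit="fp8",       # low-bit format, e.g. sym_int4 or fp8
          tensor_parallel_size=1,
          max_model_len=2000,
          max_num_batched_tokens=2000)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)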
Python style-check script:

@@ -21,7 +21,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
 PYTHON_ROOT_DIR="$SCRIPT_DIR/.."
 echo $PYTHON_ROOT_DIR
 PATHS_TO_CHECK="$SCRIPT_DIR/../../src"
-PATTERNS_TO_EXCLUDE="__init__.py,log4Error.py,$SCRIPT_DIR/../../src/ipex_llm/langchain/*,$SCRIPT_DIR/../../src/ipex_llm/transformers/gguf/models/model_implement/yuan2/*,benchmark_util_4_29.py,benchmark_util_4_42.py,benchmark_util_4_43.py,tgi_api_server.py"
+PATTERNS_TO_EXCLUDE="__init__.py,log4Error.py,$SCRIPT_DIR/../../src/ipex_llm/langchain/*,$SCRIPT_DIR/../../src/ipex_llm/transformers/gguf/models/model_implement/yuan2/*,benchmark_util_4_29.py,benchmark_util_4_42.py,benchmark_util_4_43.py,tgi_api_server.py,api_server.py"
 PEP8_REPORT_PATH="$PYTHON_ROOT_DIR/test/pep8-report.txt"
 PYLINT_REPORT_PATH="$PYTHON_ROOT_DIR/test/pylint-report.txt"
 PYLINT_INSTALL_INFO="$PYTHON_ROOT_DIR/test/pylint-info.txt"
ipex_llm.vllm.xpu.model_convert:

@@ -160,7 +160,7 @@ def is_linear_module(module):
     if is_module_in_classes(module, VLLM_LINEAR_LIST):
         if 'xpu' in _VLLM_VERSION:
             # For vllm xpu
-            from vllm.model_executor.parallel_utils.parallel_state import (
+            from vllm.distributed.parallel_state import (
                 get_tensor_model_parallel_group,
                 get_tensor_model_parallel_world_size
             )
@@ -183,8 +183,8 @@ def is_linear_module(module):
            mp_group = None
        # Check for attribute qweight
        if (not _USE_VLLM_AWQ
-                and hasattr(module.linear_method, "quant_config")
-                and module.linear_method.quant_config.get_name() == "awq"):
+                and hasattr(module.quant_method, "quant_config")
+                and module.quant_method.quant_config.get_name() == "awq"):
            _USE_VLLM_AWQ = True
        invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
                          " support linear layers with skip_bias_add argument")
@@ -231,7 +231,6 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
     from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, \
         FP16Linear, BF16Linear, vLLMLowBitLinear, vLLMFP16Linear, vLLMBF16Linear
-    # Currently, vLLM does not support optimize_lm_head = True
     optimize_lm_head = False
     if isinstance(module, ParallelLMHead):
         if qtype == ggml_tensor_qtype["fp16"]:
@@ -301,7 +300,7 @@ def convert_vllm_awq(module):
                            dtype=torch.int32) * 4).unsqueeze(0)
     # vLLM only supports load 4-bits model, so this has been checked
     bits = 4
-    group_size = module.linear_method.quant_config.group_size
+    group_size = module.quant_method.quant_config.group_size

     zeros = torch.bitwise_right_shift(
         torch.unsqueeze(module.qzeros, 2).expand(-1, -1, 32 // bits),
@@ -466,6 +465,12 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
            if any(key in full_module_name for key in modules_to_not_convert):
                continue

+        if is_linear and getattr(model_config, "model_type", None) == "chatglm" and \
+                name == "lm_head":
+            # Now we re-reference it to output_layer
+            model._modules[name] = model._modules["transformer"]._modules["output_layer"]
+            continue
+
        if is_linear and not isinstance(module, LowBitLinear):
            in_features, out_features, mp_group = linear_args
            optimize_lm_head = (
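The linear_method to quant_method rename above tracks a vLLM-side refactor: in vLLM 0.5.x the quantization strategy hangs off module.quant_method rather than module.linear_method. A minimal, hedged sketch of the detection idiom (attribute names as used in the diff; the helper name itself is hypothetical):

def uses_awq(module) -> bool:
    # Illustrative helper mirroring the check in is_linear_module():
    # vLLM >= 0.5.x attaches the quantization strategy as `quant_method`.
    quant_method = getattr(module, "quant_method", None)
    quant_config = getattr(quant_method, "quant_config", None)
    return quant_config is not None and quant_config.get_name() == "awq"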
ipex_llm.transformers.low_bit_linear:

@@ -802,7 +802,7 @@ class LowBitLinear(nn.Linear):
             result = result.view(new_shape)
             if self.mp_group is not None:
                 if get_use_vllm():
-                    torch.distributed.all_reduce(result, group=self.mp_group)
+                    result = self.mp_group.all_reduce(result)
                 elif is_deepspeed_available():
                     from deepspeed import comm as dist
                     dist.inference_all_reduce(result, group=self.mp_group)
@@ -889,7 +889,7 @@ class FP16Linear(nn.Linear):
             result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
             if self.mp_group is not None:
                 if get_use_vllm():
-                    torch.distributed.all_reduce(result, group=self.mp_group)
+                    result = self.mp_group.all_reduce(result)
                 elif is_deepspeed_available():
                     from deepspeed import comm as dist
                     dist.inference_all_reduce(result, group=self.mp_group)
@@ -926,7 +926,7 @@ class FP16Linear(nn.Linear):
             result = result.view(new_shape)
             if self.mp_group is not None:
                 if get_use_vllm():
-                    torch.distributed.all_reduce(result, group=self.mp_group)
+                    result = self.mp_group.all_reduce(result)
                 elif is_deepspeed_available():
                     from deepspeed import comm as dist
                     dist.inference_all_reduce(result, group=self.mp_group)
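The three identical replacements above reflect how vLLM 0.5.x exposes tensor-parallel groups: the object obtained from get_tensor_model_parallel_group() is now a group coordinator with its own all_reduce method that returns the reduced tensor, instead of a raw torch.distributed process group. A hedged sketch of the call-site difference (names follow the diff; the surrounding function is illustrative):

import torch

def reduce_partial_result(result: torch.Tensor, mp_group):
    # Old style (vLLM <= 0.3.x): mp_group was a torch.distributed group,
    # so the collective went through torch.distributed directly:
    #   torch.distributed.all_reduce(result, group=mp_group)
    #
    # New style (vLLM 0.5.x): mp_group wraps the collective and returns
    # the reduced tensor.
    return mp_group.all_reduce(result)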
Engine wrappers (ipex_llm.vllm.xpu.engine):

@@ -13,16 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import List, Optional, Union
+from typing import Dict, Optional
 from vllm.engine.llm_engine import LLMEngine
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
-from vllm.engine.ray_utils import initialize_ray_cluster
 from vllm.entrypoints.llm import LLM
 from vllm.utils import Counter
 from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
-from ipex_llm.utils.common import invalidInputError
+from vllm.usage.usage_lib import UsageContext
+from vllm.engine.metrics import StatLoggerBase


 class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
@@ -34,35 +33,14 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
         cls,
         engine_args: AsyncEngineArgs,
         start_engine_loop: bool = True,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         load_in_low_bit: str = "sym_int4",
-        # ipex_llm_optimize_mode: str = 'NATIVE',
+        stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
-        # Enable ipex-llm optimizations
-        engine_configs = engine_args.create_engine_configs()
-
+        # Create the engine configs.
         _ipex_llm_convert(load_in_low_bit)
-        parallel_config = engine_configs[2]
-        if parallel_config.worker_use_ray or engine_args.engine_use_ray:
-            initialize_ray_cluster(parallel_config)
-            # from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
-            from ipex_llm.vllm.xpu.ipex_llm_gpu_executor import get_gpu_executor_class_async
-            executor_class = get_gpu_executor_class_async(load_in_low_bit)
-        else:
-            invalidInputError(parallel_config.world_size == 1, (
-                "Ray is required if parallel_config.world_size > 1."))
-            from vllm.executor.gpu_executor import GPUExecutorAsync
-            executor_class = GPUExecutorAsync
-        # Create the async LLM engine.
-        engine = cls(parallel_config.worker_use_ray,
-                     engine_args.engine_use_ray,
-                     *engine_configs,
-                     executor_class,
-                     log_requests=not engine_args.disable_log_requests,
-                     log_stats=not engine_args.disable_log_stats,
-                     max_log_len=engine_args.max_log_len,
-                     start_engine_loop=start_engine_loop)
-        return engine
+        return super().from_engine_args(engine_args, start_engine_loop, usage_context, stat_loggers)


 class IPEXLLMClass(LLM):
@@ -71,6 +49,7 @@ class IPEXLLMClass(LLM):
         model: str,
         tokenizer: Optional[str] = None,
         tokenizer_mode: str = "auto",
+        skip_tokenizer_init: bool = False,
         trust_remote_code: bool = False,
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
@@ -80,18 +59,26 @@ class IPEXLLMClass(LLM):
         seed: int = 0,
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
+        cpu_offload_gb: float = 0,
         enforce_eager: bool = False,
-        max_context_len_to_capture: int = 8192,
+        max_context_len_to_capture: Optional[int] = None,
+        max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         load_in_low_bit: str = "sym_int4",
         **kwargs,
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
+        removed_vision_keys = ("image_token_id", "image_feature_size",
+                               "image_input_shape", "image_input_type")
+        if any(k in kwargs for k in removed_vision_keys):
+            raise TypeError(  # noqa
+                "There is no need to pass vision-related arguments anymore.")
         engine_args = EngineArgs(
             model=model,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
+            skip_tokenizer_init=skip_tokenizer_init,
             trust_remote_code=trust_remote_code,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
@@ -101,13 +88,16 @@ class IPEXLLMClass(LLM):
             seed=seed,
             gpu_memory_utilization=gpu_memory_utilization,
             swap_space=swap_space,
+            cpu_offload_gb=cpu_offload_gb,
             enforce_eager=enforce_eager,
             max_context_len_to_capture=max_context_len_to_capture,
+            max_seq_len_to_capture=max_seq_len_to_capture,
             disable_custom_all_reduce=disable_custom_all_reduce,
             **kwargs,
         )
-        self.llm_engine = IPEXLLMLLMEngine.from_engine_args(engine_args,
-                                                            load_in_low_bit=load_in_low_bit)
+        self.llm_engine = IPEXLLMLLMEngine.from_engine_args(
+            engine_args, usage_context=UsageContext.LLM_CLASS,
+            load_in_low_bit=load_in_low_bit)
         self.request_counter = Counter()


@@ -119,29 +109,11 @@ class IPEXLLMLLMEngine(LLMEngine):
     def from_engine_args(
         cls,
         engine_args: EngineArgs,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
         load_in_low_bit: str = "sym_int4",
-        # ipex_llm_optimize_mode: str = 'NATIVE',
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
         _ipex_llm_convert(load_in_low_bit)
-        parallel_config = engine_configs[2]
-
-        # Initialize the cluster and specify the executor class.
-        if parallel_config.worker_use_ray:
-            initialize_ray_cluster(parallel_config)
-            # from vllm.executor.ray_gpu_executor import RayGPUExecutor
-            from ipex_llm.vllm.xpu.ipex_llm_gpu_executor import get_gpu_executor_class
-            executor_class = get_gpu_executor_class(load_in_low_bit)
-        else:
-            invalidInputError(parallel_config.world_size == 1,
-                              "Ray is required if parallel_config.world_size > 1.")
-            from vllm.executor.gpu_executor import GPUExecutor
-            executor_class = GPUExecutor
-
-        # Create the LLM engine.
-        engine = cls(*engine_configs,
-                     executor_class=executor_class,
-                     log_stats=not engine_args.disable_log_stats)
-        return engine
+        return super().from_engine_args(engine_args, usage_context, stat_loggers)
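With executor selection now delegated to upstream vLLM, the only ipex-llm-specific step left in these wrappers is installing the low-bit conversion before the engine is built. A hedged usage sketch (the model path, device, and dtype values are placeholders, not prescribed by this commit):

from vllm.engine.arg_utils import AsyncEngineArgs
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine

# Placeholder checkpoint; device/dtype settings follow the serving examples above.
engine_args = AsyncEngineArgs(model="YOUR_MODEL_PATH", device="xpu", dtype="float16")

# from_engine_args() first runs _ipex_llm_convert(load_in_low_bit) and then
# defers to vLLM 0.5.4's own engine construction.
engine = IPEXLLMAsyncLLMEngine.from_engine_args(engine_args, load_in_low_bit="sym_int4")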
OpenAI-compatible API server (ipex_llm.vllm.xpu.entrypoints.openai.api_server):

@@ -1,199 +1,191 @@
-import argparse
 import asyncio
-import json
-from contextlib import asynccontextmanager
-import os
 import importlib
 import inspect
-import ssl
+import re
+from argparse import Namespace
-from prometheus_client import make_asgi_app
+from contextlib import asynccontextmanager
-import fastapi
-import uvicorn
 from http import HTTPStatus
-from fastapi import Request
+from multiprocessing import Process
+from typing import AsyncIterator, Set

+from fastapi import APIRouter, FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse, StreamingResponse, Response
+from fastapi.responses import JSONResponse, Response, StreamingResponse
+from prometheus_client import make_asgi_app
+from starlette.routing import Mount

-import vllm
+import vllm.envs as envs
+from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.entrypoints.openai.protocol import (CompletionRequest,
-                                              ChatCompletionRequest,
-                                              ErrorResponse)
-from vllm.logger import init_logger
-from vllm.entrypoints.openai.serving_engine import LoRA
-from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine
-from ipex_llm.utils.common import invalidInputError
+from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
+from vllm.engine.protocol import AsyncEngineClient
+from vllm.entrypoints.launcher import serve_http
+from vllm.entrypoints.logger import RequestLogger
+from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              ChatCompletionResponse,
+                                              CompletionRequest,
+                                              DetokenizeRequest,
+                                              DetokenizeResponse,
+                                              EmbeddingRequest, ErrorResponse,
+                                              TokenizeRequest,
+                                              TokenizeResponse)
+from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
+from ipex_llm.vllm.xpu.entrypoints.openai.rpc.server import run_rpc_server
+# yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_tokenization import (
+    OpenAIServingTokenization)
+from vllm.logger import init_logger
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import FlexibleArgumentParser, get_open_port
+from vllm.version import __version__ as VLLM_VERSION

 TIMEOUT_KEEP_ALIVE = 5  # seconds

-openai_serving_chat: OpenAIServingChat = None
-openai_serving_completion: OpenAIServingCompletion = None
-logger = init_logger(__name__)
+async_engine_client: AsyncEngineClient
+engine_args: AsyncEngineArgs
+openai_serving_chat: OpenAIServingChat
+openai_serving_completion: OpenAIServingCompletion
+openai_serving_embedding: OpenAIServingEmbedding
+openai_serving_tokenization: OpenAIServingTokenization
+
+logger = init_logger('vllm.entrypoints.openai.api_server')
+
+_running_tasks: Set[asyncio.Task] = set()
+
+
+def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
+    return ModelConfig(model=model_name,
+                       tokenizer=model_name,
+                       tokenizer_mode="auto",
+                       trust_remote_code=trust_remote_code,
+                       seed=0,
+                       dtype="float16").embedding_mode


 @asynccontextmanager
-async def lifespan(app: fastapi.FastAPI):
+async def lifespan(app: FastAPI):

     async def _force_log():
         while True:
             await asyncio.sleep(10)
-            await engine.do_log_stats()
+            await async_engine_client.do_log_stats()

     if not engine_args.disable_log_stats:
-        asyncio.create_task(_force_log())
+        task = asyncio.create_task(_force_log())
+        _running_tasks.add(task)
+        task.add_done_callback(_running_tasks.remove)

     yield


-app = fastapi.FastAPI(lifespan=lifespan)
+@asynccontextmanager
+async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
+    # Context manager to handle async_engine_client lifecycle
+    # Ensures everything is shutdown and cleaned up on error/exit
+    global engine_args
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+
+    # Backend itself still global for the silly lil' health handler
+    global async_engine_client
+
+    # If manually triggered or embedding model, use AsyncLLMEngine in process.
+    # TODO: support embedding model via RPC.
+    if (model_is_embedding(args.model, args.trust_remote_code)
+            or args.disable_frontend_multiprocessing):
+        async_engine_client = AsyncLLMEngine.from_engine_args(
+            engine_args, usage_context=UsageContext.OPENAI_API_SERVER,
+            load_in_low_bit=args.load_in_low_bit)
+        yield async_engine_client
+        return
+
+    # Otherwise, use the multiprocessing AsyncLLMEngine.
+    else:
+        # Start RPCServer in separate process (holds the AsyncLLMEngine).
+        port = get_open_port(envs.VLLM_RPC_PORT)
+        load_in_low_bit = args.load_in_low_bit
+        rpc_server_process = Process(target=run_rpc_server,
+                                     args=(engine_args,
+                                           UsageContext.OPENAI_API_SERVER,
+                                           port, load_in_low_bit))
+        rpc_server_process.start()
+
+        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
+        async_engine_client = AsyncEngineRPCClient(port)
+        await async_engine_client.setup()
+
+        try:
+            yield async_engine_client
+        finally:
+            # Ensure rpc server process was terminated
+            rpc_server_process.terminate()
+
+            # Close all open connections to the backend
+            async_engine_client.close()
+
+            # Wait for server process to join
+            rpc_server_process.join()


-class LoRAParserAction(argparse.Action):
-
-    def __call__(self, parser, namespace, values, option_string=None):
-        lora_list = []
-        for item in values:
-            name, path = item.split('=')
-            lora_list.append(LoRA(name, path))
-        setattr(namespace, self.dest, lora_list)
+router = APIRouter()


-def parse_args():
-    parser = argparse.ArgumentParser(
-        description="vLLM OpenAI-Compatible RESTful API server.")
-    parser.add_argument("--host", type=str, default=None, help="host name")
-    parser.add_argument("--port", type=int, default=8000, help="port number")
-    parser.add_argument(
-        "--uvicorn-log-level",
-        type=str,
-        default="info",
-        choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
-        help="log level for uvicorn")
-    parser.add_argument("--allow-credentials",
-                        action="store_true",
-                        help="allow credentials")
-    parser.add_argument("--allowed-origins",
-                        type=json.loads,
-                        default=["*"],
-                        help="allowed origins")
-    parser.add_argument("--allowed-methods",
-                        type=json.loads,
-                        default=["*"],
-                        help="allowed methods")
-    parser.add_argument("--allowed-headers",
-                        type=json.loads,
-                        default=["*"],
-                        help="allowed headers")
-    parser.add_argument("--api-key",
-                        type=str,
-                        default=None,
-                        help="If provided, the server will require this key "
-                        "to be presented in the header.")
-    parser.add_argument("--served-model-name",
-                        type=str,
-                        default=None,
-                        help="The model name used in the API. If not "
-                        "specified, the model name will be the same as "
-                        "the huggingface name.")
-    parser.add_argument(
-        "--lora-modules",
-        type=str,
-        default=None,
-        nargs='+',
-        action=LoRAParserAction,
-        help="LoRA module configurations in the format name=path. "
-        "Multiple modules can be specified.")
-    parser.add_argument("--chat-template",
-                        type=str,
-                        default=None,
-                        help="The file path to the chat template, "
-                        "or the template in single-line form "
-                        "for the specified model")
-    parser.add_argument("--response-role",
-                        type=str,
-                        default="assistant",
-                        help="The role name to return if "
-                        "`request.add_generation_prompt=true`.")
-    parser.add_argument("--ssl-keyfile",
-                        type=str,
-                        default=None,
-                        help="The file path to the SSL key file")
-    parser.add_argument("--ssl-certfile",
-                        type=str,
-                        default=None,
-                        help="The file path to the SSL cert file")
-    parser.add_argument("--ssl-ca-certs",
-                        type=str,
-                        default=None,
-                        help="The CA certificates file")
-    parser.add_argument(
-        "--ssl-cert-reqs",
-        type=int,
-        default=int(ssl.CERT_NONE),
-        help="Whether client certificate is required (see stdlib ssl module's)"
-    )
-    parser.add_argument(
-        "--root-path",
-        type=str,
-        default=None,
-        help="FastAPI root_path when app is behind a path based routing proxy")
-    parser.add_argument(
-        "--middleware",
-        type=str,
-        action="append",
-        default=[],
-        help="Additional ASGI middleware to apply to the app. "
-        "We accept multiple --middleware arguments. "
-        "The value should be an import path. "
-        "If a function is provided, vLLM will add it to the server "
-        "using @app.middleware('http'). "
-        "If a class is provided, vLLM will add it to the server "
-        "using app.add_middleware(). ")
-    parser.add_argument(
-        "--load-in-low-bit",
-        type=str,
-        default="sym_int4",
-        help="Low-bit quantization for IPEX-LLM models")
-
-    parser = AsyncEngineArgs.add_cli_args(parser)
-    return parser.parse_args()
-
-
-# Add prometheus asgi middleware to route /metrics requests
-metrics_app = make_asgi_app()
-app.mount("/metrics", metrics_app)
-
-
-@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(_, exc):
-    err = openai_serving_chat.create_error_response(message=str(exc))
-    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
-
-
-@app.get("/health")
+def mount_metrics(app: FastAPI):
+    # Add prometheus asgi middleware to route /metrics requests
+    metrics_route = Mount("/metrics", make_asgi_app())
+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
+    app.routes.append(metrics_route)
+
+
+@router.get("/health")
 async def health() -> Response:
     """Health check."""
-    await openai_serving_chat.engine.check_health()
+    await async_engine_client.check_health()
     return Response(status_code=200)


-@app.get("/v1/models")
+@router.post("/tokenize")
+async def tokenize(request: TokenizeRequest):
+    generator = await openai_serving_tokenization.create_tokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        assert isinstance(generator, TokenizeResponse)
+        return JSONResponse(content=generator.model_dump())
+
+
+@router.post("/detokenize")
+async def detokenize(request: DetokenizeRequest):
+    generator = await openai_serving_tokenization.create_detokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        assert isinstance(generator, DetokenizeResponse)
+        return JSONResponse(content=generator.model_dump())
+
+
+@router.get("/v1/models")
 async def show_available_models():
-    models = await openai_serving_chat.show_available_models()
+    models = await openai_serving_completion.show_available_models()
     return JSONResponse(content=models.model_dump())


-@app.get("/version")
+@router.get("/version")
 async def show_version():
-    ver = {"version": vllm.__version__}
+    ver = {"version": VLLM_VERSION}
     return JSONResponse(content=ver)


-@app.post("/v1/chat/completions")
+@router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     generator = await openai_serving_chat.create_chat_completion(
@@ -205,10 +197,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
         return StreamingResponse(content=generator,
                                  media_type="text/event-stream")
     else:
+        assert isinstance(generator, ChatCompletionResponse)
         return JSONResponse(content=generator.model_dump())


-@app.post("/v1/completions")
+@router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
@@ -222,8 +215,23 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
         return JSONResponse(content=generator.model_dump())


-if __name__ == "__main__":
-    args = parse_args()
+@router.post("/v1/embeddings")
+async def create_embedding(request: EmbeddingRequest, raw_request: Request):
+    generator = await openai_serving_embedding.create_embedding(
+        request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        return JSONResponse(content=generator.model_dump())
+
+
+def build_app(args: Namespace) -> FastAPI:
+    app = FastAPI(lifespan=lifespan)
+    app.include_router(router)
+    app.root_path = args.root_path
+
+    mount_metrics(app)

     app.add_middleware(
         CORSMiddleware,
@@ -233,11 +241,20 @@ if __name__ == "__main__":
         allow_headers=args.allowed_headers,
     )

-    token = os.environ.get("VLLM_API_KEY") or args.api_key
-    if token:
+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(_, exc):
+        err = openai_serving_chat.create_error_response(message=str(exc))
+        return JSONResponse(err.model_dump(),
+                            status_code=HTTPStatus.BAD_REQUEST)
+
+    if token := envs.VLLM_API_KEY or args.api_key:

         @app.middleware("http")
         async def authentication(request: Request, call_next):
-            if not request.url.path.startswith("/v1"):
+            root_path = "" if args.root_path is None else args.root_path
+            if request.method == "OPTIONS":
+                return await call_next(request)
+            if not request.url.path.startswith(f"{root_path}/v1"):
                 return await call_next(request)
             if request.headers.get("Authorization") != "Bearer " + token:
                 return JSONResponse(content={"error": "Unauthorized"},
@@ -252,34 +269,104 @@ if __name__ == "__main__":
         elif inspect.iscoroutinefunction(imported):
             app.middleware("http")(imported)
         else:
-            invalidInputError(False, (f"Invalid middleware {middleware}. "
-                                      f"Must be a function or a class."))
+            raise ValueError(f"Invalid middleware {middleware}. "
+                             f"Must be a function or a class.")

-    logger.info(f"vLLM API server version {vllm.__version__}")
-    logger.info(f"args: {args}")
+    return app
+
+
+async def init_app(
+    async_engine_client: AsyncEngineClient,
+    args: Namespace,
+) -> FastAPI:
+    app = build_app(args)

     if args.served_model_name is not None:
-        served_model = args.served_model_name
+        served_model_names = args.served_model_name
     else:
-        served_model = args.model
+        served_model_names = [args.model]

-    engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = IPEXLLMAsyncLLMEngine.from_engine_args(engine_args,
-                                                    load_in_low_bit=args.load_in_low_bit)
-    openai_serving_chat = OpenAIServingChat(engine, served_model,
-                                            args.response_role,
-                                            args.lora_modules,
-                                            args.chat_template)
+    model_config = await async_engine_client.get_model_config()
+
+    if args.disable_log_requests:
+        request_logger = None
+    else:
+        request_logger = RequestLogger(max_log_len=args.max_log_len)
+
+    global openai_serving_chat
+    global openai_serving_completion
+    global openai_serving_embedding
+    global openai_serving_tokenization
+
+    openai_serving_chat = OpenAIServingChat(
+        async_engine_client,
+        model_config,
+        served_model_names,
+        args.response_role,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+    )
     openai_serving_completion = OpenAIServingCompletion(
-        engine, served_model, args.lora_modules)
+        async_engine_client,
+        model_config,
+        served_model_names,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+    )
+    openai_serving_embedding = OpenAIServingEmbedding(
+        async_engine_client,
+        model_config,
+        served_model_names,
+        request_logger=request_logger,
+    )
+    openai_serving_tokenization = OpenAIServingTokenization(
+        async_engine_client,
+        model_config,
+        served_model_names,
+        lora_modules=args.lora_modules,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+    )
     app.root_path = args.root_path
-    uvicorn.run(app,
-                host=args.host,
-                port=args.port,
-                log_level=args.uvicorn_log_level,
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
-                ssl_keyfile=args.ssl_keyfile,
-                ssl_certfile=args.ssl_certfile,
-                ssl_ca_certs=args.ssl_ca_certs,
-                ssl_cert_reqs=args.ssl_cert_reqs)
+
+    return app
+
+
+async def run_server(args, **uvicorn_kwargs) -> None:
+    logger.info("vLLM API server version %s", VLLM_VERSION)
+    logger.info("args: %s", args)
+
+    async with build_async_engine_client(args) as async_engine_client:
+        app = await init_app(async_engine_client, args)
+
+        shutdown_task = await serve_http(
+            app,
+            host=args.host,
+            port=args.port,
+            log_level=args.uvicorn_log_level,
+            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+            ssl_keyfile=args.ssl_keyfile,
+            ssl_certfile=args.ssl_certfile,
+            ssl_ca_certs=args.ssl_ca_certs,
+            ssl_cert_reqs=args.ssl_cert_reqs,
+            **uvicorn_kwargs,
+        )

+    # NB: Await server shutdown only after the backend context is exited
+    await shutdown_task
+
+
+if __name__ == "__main__":
+    # NOTE(simon):
+    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
+    parser = FlexibleArgumentParser(
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args()
+
+    asyncio.run(run_server(args))
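Besides the chat and completion routes, the rebased server registers the 0.5.4 utility endpoints shown above (/tokenize, /detokenize, /v1/embeddings). A hedged sketch of exercising one of them with plain requests, assuming the server from start-vllm-service.sh is reachable on its default port and that the payload fields match vLLM 0.5.4's TokenizeRequest schema:

import requests

# Assumes the api_server above is reachable locally on port 8000.
resp = requests.post(
    "http://localhost:8000/tokenize",
    json={"model": "YOUR_MODEL_NAME", "prompt": "Hello world"},
)
resp.raise_for_status()
print(resp.json())  # expected to contain the token ids for the prompt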
python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py (new file, 163 lines)

@@ -0,0 +1,163 @@
"""
This file contains the command line arguments for the vLLM's
OpenAI-compatible server. It is kept in a separate file for documentation
purposes.
"""

import argparse
import json
import ssl

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    PromptAdapterPath)
from vllm.utils import FlexibleArgumentParser


class LoRAParserAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        lora_list = []
        for item in values:
            name, path = item.split('=')
            lora_list.append(LoRAModulePath(name, path))
        setattr(namespace, self.dest, lora_list)


class PromptAdapterParserAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        adapter_list = []
        for item in values:
            name, path = item.split('=')
            adapter_list.append(PromptAdapterPath(name, path))
        setattr(namespace, self.dest, adapter_list)


def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    parser.add_argument("--host",
                        type=nullable_str,
                        default=None,
                        help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument(
        "--uvicorn-log-level",
        type=str,
        default="info",
        choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
        help="log level for uvicorn")
    parser.add_argument("--allow-credentials",
                        action="store_true",
                        help="allow credentials")
    parser.add_argument("--allowed-origins",
                        type=json.loads,
                        default=["*"],
                        help="allowed origins")
    parser.add_argument("--allowed-methods",
                        type=json.loads,
                        default=["*"],
                        help="allowed methods")
    parser.add_argument("--allowed-headers",
                        type=json.loads,
                        default=["*"],
                        help="allowed headers")
    parser.add_argument("--api-key",
                        type=nullable_str,
                        default=None,
                        help="If provided, the server will require this key "
                        "to be presented in the header.")
    parser.add_argument(
        "--lora-modules",
        type=nullable_str,
        default=None,
        nargs='+',
        action=LoRAParserAction,
        help="LoRA module configurations in the format name=path. "
        "Multiple modules can be specified.")
    parser.add_argument(
        "--prompt-adapters",
        type=nullable_str,
        default=None,
        nargs='+',
        action=PromptAdapterParserAction,
        help="Prompt adapter configurations in the format name=path. "
        "Multiple adapters can be specified.")
    parser.add_argument("--chat-template",
                        type=nullable_str,
                        default=None,
                        help="The file path to the chat template, "
                        "or the template in single-line form "
                        "for the specified model")
    parser.add_argument("--response-role",
                        type=nullable_str,
                        default="assistant",
                        help="The role name to return if "
                        "`request.add_generation_prompt=true`.")
    parser.add_argument("--ssl-keyfile",
                        type=nullable_str,
                        default=None,
                        help="The file path to the SSL key file")
    parser.add_argument("--ssl-certfile",
                        type=nullable_str,
                        default=None,
                        help="The file path to the SSL cert file")
    parser.add_argument("--ssl-ca-certs",
                        type=nullable_str,
                        default=None,
                        help="The CA certificates file")
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl module's)"
    )
    parser.add_argument(
        "--root-path",
        type=nullable_str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser.add_argument(
        "--middleware",
        type=nullable_str,
        action="append",
        default=[],
        help="Additional ASGI middleware to apply to the app. "
        "We accept multiple --middleware arguments. "
        "The value should be an import path. "
        "If a function is provided, vLLM will add it to the server "
        "using @app.middleware('http'). "
|
||||||
|
"If a class is provided, vLLM will add it to the server "
|
||||||
|
"using app.add_middleware(). ")
|
||||||
|
parser.add_argument(
|
||||||
|
"--return-tokens-as-token-ids",
|
||||||
|
action="store_true",
|
||||||
|
help="When --max-logprobs is specified, represents single tokens as "
|
||||||
|
"strings of the form 'token_id:{token_id}' so that tokens that "
|
||||||
|
"are not JSON-encodable can be identified.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-frontend-multiprocessing",
|
||||||
|
action="store_true",
|
||||||
|
help="If specified, will run the OpenAI frontend server in the same "
|
||||||
|
"process as the model serving engine.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--load-in-low-bit",
|
||||||
|
type=str,
|
||||||
|
default="sym_int4",
|
||||||
|
help="Low-bit quantization for IPEX-LLM models")
|
||||||
|
|
||||||
|
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||||
|
|
||||||
|
parser.add_argument('--max-log-len',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='Max number of prompt characters or prompt '
|
||||||
|
'ID numbers being printed in log.'
|
||||||
|
'\n\nDefault: Unlimited')
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def create_parser_for_docs() -> FlexibleArgumentParser:
|
||||||
|
parser_for_docs = FlexibleArgumentParser(
|
||||||
|
prog="-m vllm.entrypoints.openai.api_server")
|
||||||
|
return make_arg_parser(parser_for_docs)
|
||||||
|
|
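A quick sketch of what the two argparse actions above produce; the model id and LoRA paths are placeholders:

from vllm.utils import FlexibleArgumentParser
from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser

parser = make_arg_parser(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "meta-llama/Meta-Llama-3-8B-Instruct",   # placeholder model id
    "--lora-modules", "sql-lora=/data/loras/sql", "chat-lora=/data/loras/chat",
    "--load-in-low-bit", "sym_int4",
])
# Each name=path pair is turned into a LoRAModulePath by LoRAParserAction.
print(args.lora_modules)     # [LoRAModulePath(...), LoRAModulePath(...)]
print(args.load_in_low_bit)  # "sym_int4"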

(new file, 221 lines: the ZeroMQ RPC server that hosts IPEXLLMAsyncLLMEngine)

@@ -0,0 +1,221 @@

import asyncio
import signal
from typing import Any, Coroutine

import cloudpickle
import zmq
import zmq.asyncio
from typing_extensions import Never

from vllm import AsyncEngineArgs
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine

from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

logger = init_logger(__name__)


class AsyncEngineRPCServer:

    def __init__(self, async_engine_args: AsyncEngineArgs,
                 usage_context: UsageContext, port: int, load_in_low_bit: str):
        # Initialize engine first.
        self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
                                                      usage_context=usage_context,
                                                      load_in_low_bit=load_in_low_bit)

        # Initialize context.
        self.context = zmq.asyncio.Context()

        # Init socket for readiness state.
        self.socket = self.context.socket(zmq.constants.ROUTER)
        # Note numeric form of localhost should be used for zmq bind(),
        # see https://stackoverflow.com/a/8958414
        self.socket.bind(f"tcp://127.0.0.1:{port}")

    def cleanup(self):
        """Cleanup all resources."""
        self.socket.close()
        self.context.destroy()

    async def get_model_config(self, identity):
        """Send the ModelConfig"""
        model_config = await self.engine.get_model_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(model_config)])

    async def get_decoding_config(self, identity):
        """Send the DecodingConfig"""
        decoding_config = await self.engine.get_decoding_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(decoding_config)])

    async def get_lora_config(self, identity):
        lora_config = await self.engine.get_lora_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(lora_config)])

    async def get_scheduler_config(self, identity):
        """Send the SchedulerConfig"""
        parallel_config = await self.engine.get_scheduler_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(parallel_config)])

    async def get_parallel_config(self, identity):
        """Send the ParallelConfig"""
        parallel_config = await self.engine.get_parallel_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(parallel_config)])

    async def is_tracing_enabled(self, identity):
        """Send the is_tracing_enabled flag"""
        tracing_flag = await self.engine.is_tracing_enabled()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(tracing_flag)])

    async def do_log_stats(self, identity):
        """Log stats and confirm success."""
        await self.engine.do_log_stats()

        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
        ])

    async def is_server_ready(self, identity):
        """Notify the client that we are ready."""
        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
        ])

    async def abort(self, identity, request: RPCAbortRequest):
        """Abort request and notify the client of success."""
        # Abort the request in the llm engine.
        await self.engine.abort(request.request_id)

        # Send confirmation to the client.
        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
        ])

    async def generate(self, identity, generate_request: RPCGenerateRequest):
        try:
            results_generator = self.engine.generate(
                generate_request.inputs,
                sampling_params=generate_request.sampling_params,
                request_id=generate_request.request_id,
                lora_request=generate_request.lora_request,
                trace_headers=generate_request.trace_headers,
                prompt_adapter_request=generate_request.prompt_adapter_request)

            async for request_output in results_generator:
                await self.socket.send_multipart(
                    [identity, cloudpickle.dumps(request_output)])

        except Exception as e:
            # Notify client of all failures
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

    async def check_health(self, identity):
        try:
            await self.engine.check_health()
            await self.socket.send_multipart(
                [identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])
        except Exception as e:
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

    def _make_handler_coro(self, identity,
                           message) -> Coroutine[Any, Any, Never]:
        """Route the zmq message to the handler coroutine."""

        request = cloudpickle.loads(message)

        if isinstance(request, RPCGenerateRequest):
            return self.generate(identity, request)

        elif isinstance(request, RPCAbortRequest):
            return self.abort(identity, request)

        elif isinstance(request, RPCUtilityRequest):
            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
                return self.get_model_config(identity)
            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
                return self.get_parallel_config(identity)
            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
                return self.get_decoding_config(identity)
            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
                return self.get_scheduler_config(identity)
            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
                return self.get_lora_config(identity)
            elif request == RPCUtilityRequest.DO_LOG_STATS:
                return self.do_log_stats(identity)
            elif request == RPCUtilityRequest.IS_SERVER_READY:
                return self.is_server_ready(identity)
            elif request == RPCUtilityRequest.CHECK_HEALTH:
                return self.check_health(identity)
            elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
                return self.is_tracing_enabled(identity)
            else:
                raise ValueError(f"Unknown RPCUtilityRequest type: {request}")  # noqa
        else:
            raise ValueError(f"Unknown RPCRequest type: {request}")  # noqa

    async def run_server_loop(self):
        """Inner RPC Server Loop"""

        running_tasks = set()
        while True:
            # Wait for a request.
            identity, message = await self.socket.recv_multipart()

            # Process the request async.
            task = asyncio.create_task(
                self._make_handler_coro(identity, message))

            # We need to keep around a strong reference to the task,
            # to avoid the task disappearing mid-execution as running tasks
            # can be GC'ed. Below is a common "fire-and-forget" tasks
            # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
            running_tasks.add(task)
            task.add_done_callback(running_tasks.discard)


async def run_server(server: AsyncEngineRPCServer):
    # Put the server task into the asyncio loop.
    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.run_server_loop())

    # Interruption handling.
    def signal_handler() -> None:
        # Kill the server on interrupt / terminate
        server_task.cancel()

    loop.add_signal_handler(signal.SIGINT, signal_handler)
    loop.add_signal_handler(signal.SIGTERM, signal_handler)

    try:
        await server_task
    except asyncio.CancelledError:
        logger.info("vLLM ZMQ RPC Server was interrupted.")
    finally:
        # Clean up all resources.
        server.cleanup()


def run_rpc_server(async_engine_args: AsyncEngineArgs,
                   usage_context: UsageContext, port: int, load_in_low_bit: str):
    server = AsyncEngineRPCServer(async_engine_args, usage_context, port, load_in_low_bit)
    asyncio.run(run_server(server))
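For orientation, a minimal client-side sketch of the wire protocol the server above speaks (a DEALER socket paired with the server's ROUTER, cloudpickle-encoded payloads). This is not the RPC client vLLM actually ships, and the port value is a placeholder:

import asyncio

import cloudpickle
import zmq
import zmq.asyncio

from vllm.entrypoints.openai.rpc import VLLM_RPC_SUCCESS_STR, RPCUtilityRequest


async def wait_until_ready(port: int = 5570) -> bool:
    ctx = zmq.asyncio.Context()
    sock = ctx.socket(zmq.DEALER)           # pairs with the server's ROUTER socket
    sock.connect(f"tcp://127.0.0.1:{port}")
    # The server answers IS_SERVER_READY with cloudpickle.dumps(VLLM_RPC_SUCCESS_STR).
    await sock.send_multipart([cloudpickle.dumps(RPCUtilityRequest.IS_SERVER_READY)])
    reply = cloudpickle.loads((await sock.recv_multipart())[0])
    sock.close()
    ctx.destroy()
    return reply == VLLM_RPC_SUCCESS_STR

# asyncio.run(wait_until_ready())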

@@ -1,466 +0,0 @@ (deleted file: the previous Ray-based IPEXLLMGPUExecutor)

import asyncio
import copy
from collections import defaultdict
import os
import pickle
import importlib
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (set_cuda_visible_devices, get_ip, get_open_port,
                        get_distributed_init_method, make_async)
import functools
from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
from ipex_llm.utils.common import invalidInputError

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)

# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP = {
    "cuda": "vllm.worker.worker",
    "xpu": "vllm.worker.worker",
    "neuron": "vllm.worker.neuron_worker",
}

# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))


class IPEXLLMGPUExecutor(ExecutorBase):

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        load_in_low_bit: str,
    ) -> None:
        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.load_in_low_bit = load_in_low_bit

        invalidInputError(self.parallel_config.worker_use_ray,
                          "worker_use_ray is False, but use ray worker")
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        # Profile the memory usage and initialize the cache.
        self._init_cache()

        self.forward_dag = None
        if USE_RAY_COMPILED_DAG:
            self.forward_dag = self._compiled_ray_dag()

    def _dispatch_worker(self):
        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
            self.device_config.device_type]
        imported_worker = importlib.import_module(worker_module)
        Worker = imported_worker.Worker
        return Worker

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        if self.parallel_config.tensor_parallel_size == 1:
            # For single GPU case, we use a ray worker with constrained memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: RayWorkerVllm = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerVllm] = []

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerVllm).remote(self.model_config.trust_remote_code)

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
            else:
                # Else, added to the list of workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            invalidInputError(False,
                              "Ray does not allocate any GPUs on the driver node. Consider "
                              "adjusting the Ray placement group or running the driver on a "
                              "GPU node.")
        # Get the set of GPU IDs used on each node.
        driver_node_id, driver_gpu_ids = ray.get(
            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
        worker_node_and_gpu_ids = ray.get(
            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])

        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)

        node_workers[driver_node_id].append(0)
        node_gpus[driver_node_id].extend(driver_gpu_ids)
        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
                                               start=1):
            node_workers[node_id].append(i)
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        # Set CUDA_VISIBLE_DEVICES for the driver and workers.
        set_cuda_visible_devices(node_gpus[driver_node_id])
        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
            worker.set_cuda_visible_devices.remote(node_gpus[node_id])

        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        Worker = self._dispatch_worker()

        model_config = copy.deepcopy(self.model_config)
        parallel_config = copy.deepcopy(self.parallel_config)
        scheduler_config = copy.deepcopy(self.scheduler_config)
        device_config = copy.deepcopy(self.device_config)
        lora_config = copy.deepcopy(self.lora_config)
        kv_cache_dtype = self.cache_config.cache_dtype

        # Initialize the actual workers with the Worker class.
        for rank, (worker, (node_id, _)) in enumerate(
            zip(self.workers, worker_node_and_gpu_ids),
            start=1,
        ):
            local_rank = node_workers[node_id].index(rank)

            def create_worker_function(rank, local_rank, load_in_low_bit):
                def worker_function():
                    _ipex_llm_convert(load_in_low_bit)
                    return Worker(
                        model_config,
                        parallel_config,
                        scheduler_config,
                        device_config,
                        local_rank,
                        rank,
                        distributed_init_method,
                        lora_config=lora_config,
                        kv_cache_dtype=kv_cache_dtype,
                    )
                return worker_function
            worker.init_worker.remote(create_worker_function(rank,
                                                             local_rank,
                                                             self.load_in_low_bit))

        # Initialize the driver worker with the Worker class.
        driver_rank = 0
        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
        self.driver_worker = Worker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            self.device_config,
            driver_local_rank,
            driver_rank,
            distributed_init_method,
            lora_config=self.lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=True,
        )

        # We want to apply patch here before we loading the model
        # FIXME(woosuk): We are not properly initializing cupy NCCL when
        # we have multiple nodes.
        self._run_workers("init_model",
                          cupy_port=get_open_port()
                          if not model_config.enforce_eager else None)
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.
            max_parallel_loading_workers,
        )

    def _init_cache(self) -> None:
        """Profiles the memory usage and initializes the KV cache.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculate the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.
        More details can be found in the
        :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
        from class :class:`~vllm.worker.Worker`.

        Afterwards, as there may be multiple workers,
        we take the minimum number of blocks across all workers
        to ensure this can be applied to all of them.

        Finally, the engine will initialize the KV cache
        with the calculated number of blocks.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers(
            "profile_num_available_blocks",
            block_size=self.cache_config.block_size,
            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
            cpu_swap_space=self.cache_config.swap_space_bytes,
            cache_dtype=self.cache_config.cache_dtype,
        )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")

        check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
                               self.model_config.max_model_len)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        # Initialize the cache.
        self._run_workers("init_cache_engine", cache_config=self.cache_config)
        # Warm up the model. This includes capturing the model into CUDA graph
        # if enforce_eager is False.
        self._run_workers("warm_up_model")

    def execute_model(self,
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      blocks_to_swap_in: Dict[int, int],
                      blocks_to_swap_out: Dict[int, int],
                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
        all_outputs = self._run_workers(
            "execute_model",
            driver_kwargs={
                "seq_group_metadata_list": seq_group_metadata_list,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            },
            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

        # Only the driver worker returns the sampling results.
        output = all_outputs[0]
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        invalidInputError(lora_request.lora_int_id > 0,
                          "lora_id must be greater than 0.")
        return self._run_workers(
            "add_lora",
            lora_request=lora_request,
        )

    def remove_lora(self, lora_id: int) -> bool:
        invalidInputError(lora_id > 0, "lora_id must be greater than 0.")
        return self._run_workers(
            "remove_lora",
            lora_id=lora_id,
        )

    def list_loras(self) -> List[int]:
        return self._run_workers("list_loras")

    def _run_workers(
        self,
        method: str,
        *args,
        driver_args: Optional[List[Any]]=None,
        driver_kwargs: Optional[Dict[str, Any]]=None,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""

        if max_concurrent_workers:
            invalidInputError(False,
                              "max_concurrent_workers is not supported yet.")

        if use_ray_compiled_dag:
            # Right now, compiled DAG can only accept a single
            # input. TODO(sang): Fix it.
            output_channels = self.forward_dag.execute(1)
        else:
            # Start the ray workers first.
            ray_worker_outputs = [
                worker.execute_method.remote(method, *args, **kwargs)
                for worker in self.workers
            ]

        if driver_args is None:
            driver_args = args
        if driver_kwargs is None:
            driver_kwargs = kwargs

        # Start the driver worker after all the ray workers.
        driver_worker_output = getattr(self.driver_worker,
                                       method)(*driver_args, **driver_kwargs)

        # Get the results of the ray workers.
        if self.workers:
            if use_ray_compiled_dag:
                try:
                    ray_worker_outputs = [
                        pickle.loads(chan.begin_read())
                        for chan in output_channels
                    ]
                finally:
                    # Has to call end_read in order to reuse the DAG.
                    for chan in output_channels:
                        chan.end_read()
            else:
                ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _compiled_ray_dag(self):
        import pkg_resources
        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        if current_version < required_version:
            invalidInputError(False,
                              f"Ray version {required_version} or greater is "
                              f"required, but found {current_version}")

        from ray.dag import MultiOutputNode, InputNode
        invalidInputError(self.parallel_config.worker_use_ray,
                          "Use ray worker, but worker_use_ray is False")

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.bind(input_data)
                for worker in self.workers
            ])
        return forward_dag.experimental_compile()

    def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        self._check_if_any_actor_is_dead()

    def _check_if_any_actor_is_dead(self):
        if not self.workers:
            return

        dead_actors = []
        for actor in self.workers:
            actor_state = ray.state.actors(actor._ray_actor_id.hex())
            if actor_state["State"] == "DEAD":
                dead_actors.append(actor)
        if dead_actors:
            invalidInputError("At least one Worker is dead. "
                              f"Dead Workers: {dead_actors}. ")


class IPEXLLMGPUExecutorAsync(IPEXLLMGPUExecutor, ExecutorAsyncBase):

    async def _run_workers_async(
        self,
        method: str,
        *args,
        driver_args: Optional[List[Any]]=None,
        driver_kwargs: Optional[Dict[str, Any]]=None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
        coros = []

        if driver_args is None:
            driver_args = args
        if driver_kwargs is None:
            driver_kwargs = kwargs

        # Run the driver worker asynchronously.
        driver_executor = make_async(getattr(self.driver_worker, method))
        coros.append(driver_executor(*driver_args, **driver_kwargs))

        # Run the ray workers asynchronously.
        for worker in self.workers:
            coros.append(worker.execute_method.remote(method, *args, **kwargs))

        all_outputs = await asyncio.gather(*coros)
        return all_outputs

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        all_outputs = await self._run_workers_async(
            "execute_model",
            driver_kwargs={
                "seq_group_metadata_list": seq_group_metadata_list,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            })

        # Only the driver worker returns the sampling results.
        output = all_outputs[0]
        return output

    async def check_health_async(self) -> None:
        """Raises an error if engine is unhealthy."""
        self._check_if_any_actor_is_dead()


def get_gpu_executor_class(load_in_low_bit):
    return functools.partial(IPEXLLMGPUExecutor, load_in_low_bit=load_in_low_bit)


def get_gpu_executor_class_async(load_in_low_bit):
    return functools.partial(IPEXLLMGPUExecutorAsync, load_in_low_bit=load_in_low_bit)

python/llm/src/ipex_llm/vllm/xpu/ipex_llm_wrapper.py (new file, 24 lines)

@@ -0,0 +1,24 @@

from vllm.logger import init_logger
from vllm.executor.ray_utils import RayWorkerWrapper


logger = init_logger(__name__)


class IPEXLLMWrapper(RayWorkerWrapper):
    def __init__(self, load_in_low_bit="sym_int4", *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
        _ipex_llm_convert(load_in_low_bit=load_in_low_bit)
        self.compiled_dag_cuda_device_set = False


def get_ipex_llm_wrapper(load_in_low_bit):
    # The reason why we not using functools.partial is that
    # ray seems not work well with it.
    class WrapperWithLoadBit(IPEXLLMWrapper):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(load_in_low_bit=load_in_low_bit, *args, **kwargs)

    # a = functools.partial(IPEXLLMWrapper, load_in_low_bit=load_in_low_bit)
    return WrapperWithLoadBit
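A short usage note built only from names introduced in this file: the closure-built subclass (rather than functools.partial, per the comment above) is what gets substituted for vLLM's RayWorkerWrapper, so each Ray worker runs the low-bit conversion when it is constructed.

from vllm.executor.ray_utils import RayWorkerWrapper

from ipex_llm.vllm.xpu.ipex_llm_wrapper import get_ipex_llm_wrapper

# Returns a RayWorkerWrapper subclass whose __init__ already carries the
# low-bit setting; _ipex_llm_convert patches vllm.executor.ray_utils to use it
# (see the model-conversion hunks in the next file).
WrapperCls = get_ipex_llm_wrapper("sym_int4")
assert issubclass(WrapperCls, RayWorkerWrapper)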

@@ -14,6 +14,8 @@

Two imports are added to the existing import block:

# limitations under the License.
#
import torch
from typing import Optional, Union                    # added
from vllm.distributed import tensor_model_parallel_gather, tensor_model_parallel_all_gather  # added
from vllm.logger import init_logger
from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention, LlamaForCausalLM
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention, Qwen2ForCausalLM

@@ -22,236 +24,78 @@ from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention

Removed (the per-model forward/sample replacements and wiring used with the previous vLLM version):

from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention, ChatGLMForCausalLM
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.sampler import Sampler
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.input_metadata import InputMetadata
from vllm.config import DeviceConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather

from typing import Tuple, Optional, Union
from ipex_llm.utils.common import invalidInputError
from vllm.sequence import SamplerOutput


def _Llama_sample(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    next_tokens = self.sampler(self.lm_head, hidden_states,
                               sampling_metadata)
    return next_tokens


def _Qwen2_sample(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    if self.config.tie_word_embeddings:
        # Embedding layer is not optimized to LowBitLinear
        lm_head_weight = self.model.embed_tokens.weight
    else:
        # This layer is optimized to LowBitLinear
        lm_head_weight = self.lm_head
    next_tokens = self.sampler(lm_head_weight, hidden_states,
                               sampling_metadata)
    return next_tokens


def _Chatglm_sample(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    next_tokens = self.sampler(self.transformer.output_layer, hidden_states,
                               sampling_metadata)

    return next_tokens


def _sample_get_logits(self, hidden_states: torch.Tensor,
                       embedding: Union[torch.nn.Module, torch.Tensor],
                       embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
    # For tie_word_embedding models, the embedding is not optimized as
    # the low_bit_linear layer...
    if isinstance(embedding, torch.Tensor):
        logits = torch.matmul(hidden_states, embedding.t())
    else:
        logits = embedding(hidden_states)
    if embedding_bias is not None:
        logits += embedding_bias
    logits = tensor_model_parallel_gather(logits)
    # Remove paddings in vocab (if any).
    if logits is not None:
        logits = logits[:, :self.org_vocab_size]
    return logits


def _MLP_forward(self, x):
    gate_up = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x = self.down_proj(x)
    return x


def _Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
    qkv = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    k_cache, v_cache = kv_cache
    attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
    output = self.o_proj(attn_output)
    return output


def _QWen_Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    input_metadata: InputMetadata,
) -> torch.Tensor:
    qkv = self.c_attn(hidden_states)
    q, k, v = qkv.chunk(chunks=3, dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    k_cache, v_cache = kv_cache
    attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
    output = self.c_proj(attn_output)
    return output


def _QWen_MLP_forward(self, x):
    gate_up = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x = self.c_proj(x)
    return x


def _ChatGLM_MLP_forward(self, hidden_states):
    # [s, b, 4hp]
    intermediate_parallel = self.dense_h_to_4h(hidden_states)
    intermediate_parallel = self.activation_func(intermediate_parallel)
    # [s, b, h]
    output = self.dense_4h_to_h(intermediate_parallel)
    return output


def _Baichuan_Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    input_metadata: InputMetadata,
) -> torch.Tensor:
    qkv = self.W_pack(hidden_states)
    q, k, v = qkv.chunk(chunks=3, dim=-1)
    if self.postion_embedding != "ALIBI":
        q, k = self.rotary_emb(positions, q, k)
    k_cache, v_cache = kv_cache
    attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
    output = self.o_proj(attn_output)
    return output


def _ChatGLM_Attention_forward(
    self,
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor,
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    input_metadata: InputMetadata,
) -> torch.Tensor:
    qkv = self.query_key_value(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    q, k = self.rotary_emb(position_ids, q, k)
    key_cache, value_cache = kv_cache
    context_layer = self.attn(
        q,
        k,
        v,
        key_cache,
        value_cache,
        input_metadata,
    )
    attn_output = self.dense(context_layer)
    return attn_output

_REPLACED_MLP_LAYERS = {
    LlamaMLP: _MLP_forward,
    Qwen2MLP: _MLP_forward,
    BaiChuanMLP: _MLP_forward,
    QWenMLP: _QWen_MLP_forward,
    GLMMLP: _ChatGLM_MLP_forward
}

_REPLACED_ATTENTION_LAYERS = {
    LlamaAttention: _Attention_forward,
    Qwen2Attention: _Attention_forward,
    QWenAttention: _QWen_Attention_forward,
    BaiChuanAttention: _Baichuan_Attention_forward,
    GLMAttention: _ChatGLM_Attention_forward
}

_REPLACED_SAMPLER_LAYERS = {
    LlamaForCausalLM: _Llama_sample,
    QWenLMHeadModel: _Llama_sample,
    ChatGLMForCausalLM: _Chatglm_sample,
    Qwen2ForCausalLM: _Qwen2_sample,
    BaiChuanBaseForCausalLM: _Llama_sample,
}


def _model_mlp_convert():
    for module, replaced_func in _REPLACED_MLP_LAYERS.items():
        setattr(module, "forward", replaced_func)


def _model_sample_convert():
    setattr(Sampler, "_get_logits", _sample_get_logits)
    for module, replaced_func in _REPLACED_SAMPLER_LAYERS.items():
        setattr(module, "sample", replaced_func)


def _model_attention_convert():
    for module, replaced_func in _REPLACED_ATTENTION_LAYERS.items():
        setattr(module, "forward", replaced_func)


def _ipex_llm_convert(load_in_low_bit):
    from vllm.worker.model_runner import ModelRunner
    import vllm.model_executor.model_loader as model_loader
    setattr(ModelRunner, "load_model", get_load_function(load_in_low_bit))


def get_load_function(low_bit):
    def _ipex_llm_load_model(self) -> None:
        # _model_mlp_convert()
        # _model_attention_convert()
        _model_sample_convert()

        from vllm.utils import measure_device_memory
        with measure_device_memory() as m:
            # only support xpu for now
            # We have to create a new DeviceConfig.
            # Otherwise, will get the wrong xpu memory usage
            self.model = get_model(self.model_config,
                                   DeviceConfig("cpu"),
                                   lora_config=self.lora_config,
                                   parallel_config=self.parallel_config,
                                   scheduler_config=self.scheduler_config)
            from ipex_llm import optimize_model
            import os
            not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)
            is_glm4_model = "glm-4" in self.model_config.model.lower()
            if not_convert_last_mlp is not None or is_glm4_model:
                # only use to avoid nan value in last mlp forward running glm4-9b-chat
                modules = ["35.mlp", "36.mlp", "37.mlp", "38.mlp", "39.mlp"]
            else:

Added (a single _get_logits override for the new LogitsProcessor, plus the new model-runner and Ray wiring):

from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention, ChatGLMForCausalLM
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.attention import AttentionMetadata
from vllm.config import DeviceConfig
from typing import Tuple
from ipex_llm.transformers.low_bit_linear import LowBitLinear


def _sample_get_logits(
    self,
    hidden_states: torch.Tensor,
    lm_head: Union[VocabParallelEmbedding, LowBitLinear],
    embedding_bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # HINT: we do not support other types of quantization for now
    # TODO: we may encounter tie-word-embedding problems
    if isinstance(lm_head, VocabParallelEmbedding):
        logits = lm_head.linear_method.apply(lm_head,
                                             hidden_states,
                                             bias=embedding_bias)
    else:
        logits = lm_head(hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
    if self.use_gather:
        logits = tensor_model_parallel_gather(logits)
    else:
        logits = tensor_model_parallel_all_gather(logits)
    if logits is not None:
        logits = logits[:, : self.org_vocab_size]
    return logits


def _model_sample_convert():
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
    setattr(LogitsProcessor, "_get_logits", _sample_get_logits)


def _ipex_llm_convert(load_in_low_bit):
    from vllm.worker.xpu_model_runner import XPUModelRunner
    from ipex_llm.vllm.xpu.ipex_llm_wrapper import get_ipex_llm_wrapper
    import vllm.executor.ray_utils as ray_utils
    setattr(XPUModelRunner, "load_model", get_load_function(load_in_low_bit))
    setattr(ray_utils, "RayWorkerWrapper", get_ipex_llm_wrapper(load_in_low_bit))


def get_load_function(low_bit):
    def _ipex_llm_load_model(self) -> None:
        _model_sample_convert()

        # from vllm.utils import measure_device_memory
        from vllm.utils import CudaMemoryProfiler
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=DeviceConfig("cpu"),
                load_config=self.load_config,
                lora_config=self.lora_config,
                multimodal_config=self.multimodal_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
                cache_config=self.cache_config,
            )
            if "qwen" in self.model_config.model.lower() and \
                    self.model.model.layers[0].mlp.down_proj.input_size_per_partition % 256 != 0:
                self.model.apply(padding_mlp)
            from ipex_llm import optimize_model
            import os
            not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)
            is_glm4_model = "glm-4" in self.model_config.model.lower()
            is_codegeex4_model = "codegeex4-all" in self.model_config.model.lower()
            if not_convert_last_mlp is not None or is_glm4_model or is_codegeex4_model:
                # only use to avoid nan value in last mlp forward running glm4-9b-chat
                modules = ["35.mlp", "36.mlp", "37.mlp", "38.mlp", "39.mlp"]
            else:

@@ -263,22 +107,34 @@

Removed (f-string logging and the old in-loader LoRA manager setup):

            self.model_memory_usage = m.consumed_memory
            logger = init_logger(__name__)
            logger.info(f"Loading model weights took "
                        f"{self.model_memory_usage / float(2**30):.4f} GB")

        if self.lora_config:
            invalidInputError(hasattr(self.model, "supported_lora_modules")
                              and self.model.supported_lora_modules,
                              "Model does not support LoRA")
            invalidInputError(hasattr(self.model, "embedding_modules"),
                              "Model does not have embedding_modules")
            invalidInputError(hasattr(self.model, "embedding_padding_modules"),
                              "Model does not have embedding_padding_modules")
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens +
                self.scheduler_config.max_paddings, self.vocab_size,
                self.lora_config, self.device, self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)
    return _ipex_llm_load_model

Added (lazy %-style logging and the new padding_mlp helper for Qwen models):

            self.model_memory_usage = m.consumed_memory
            logger = init_logger(__name__)
            logger.info("Loading model weights took %.4f GB",
                        self.model_memory_usage / float(2**30))

    return _ipex_llm_load_model


def padding_mlp(module: torch.nn.Module):
    if isinstance(module, Qwen2MLP):
        hidden_size = module.down_proj.output_size
        # devide by rank
        intermediate_size = module.down_proj.input_size_per_partition
        padding_size = 256
        padding_intermediate_size = \
            (intermediate_size + padding_size - 1) // padding_size * padding_size
        if intermediate_size % padding_size == 0:
            return
        gate_up_weight = module.gate_up_proj.weight.data
        new_gate_up_weight = torch.zeros([padding_intermediate_size * 2, hidden_size],
                                         dtype=gate_up_weight.dtype, device=gate_up_weight.device)
        # merge_gate_up_weight
        new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :]
        new_gate_up_weight[padding_intermediate_size:padding_intermediate_size+intermediate_size, :] = gate_up_weight[intermediate_size:, :]  # noqa
        module.gate_up_proj.output_size_per_partition = padding_intermediate_size * 2
        module.gate_up_proj.weight = torch.nn.Parameter(new_gate_up_weight, requires_grad=False)

        down_weight = module.down_proj.weight.data
        new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
                                      dtype=down_weight.dtype, device=down_weight.device)
        new_down_weight[:, :intermediate_size] = down_weight
        module.down_proj.input_size_per_partition = padding_intermediate_size
        module.down_proj.weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
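To make the rounding in padding_mlp concrete, a small worked example; the sizes are illustrative values, not taken from the patch:

padding_size = 256

intermediate_size = 11008   # already a multiple of 256
padded = (intermediate_size + padding_size - 1) // padding_size * padding_size
assert padded == 11008      # padding_mlp would return early in this case

intermediate_size = 1000    # hypothetical per-rank size that is not a multiple of 256
padded = (intermediate_size + padding_size - 1) // padding_size * padding_size
assert padded == 1024       # gate_up/down weights get zero-padded up to this size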