From 69c8d36f166152af4399860c42477696b6a37ee0 Mon Sep 17 00:00:00 2001 From: Guancheng Fu <110874468+gc-fu@users.noreply.github.com> Date: Tue, 10 Sep 2024 15:37:43 +0800 Subject: [PATCH] Switching from vLLM v0.3.3 to vLLM 0.5.4 (#12042) * Enable single card sync engine * enable ipex-llm optimizations for vllm * enable optimizations for lm_head * Fix chatglm multi-reference problem * Remove duplicate layer * LLM: Update vLLM to v0.5.4 (#11746) * Enable single card sync engine * enable ipex-llm optimizations for vllm * enable optimizations for lm_head * Fix chatglm multi-reference problem * update 0.5.4 api_server * add dockerfile * fix * fix * refine * fix --------- Co-authored-by: gc-fu * Add vllm-0.5.4 Dockerfile (#11838) * Update BIGDL_LLM_SDP_IGNORE_MASK in start-vllm-service.sh (#11957) * Fix vLLM not convert issues (#11817) (#11918) * Fix not convert issues * refine Co-authored-by: Guancheng Fu <110874468+gc-fu@users.noreply.github.com> * Fix glm4-9b-chat nan error on vllm 0.5.4 (#11969) * init * update mlp forward * fix minicpm error in vllm 0.5.4 * fix dependabot alerts (#12008) * Update 0.5.4 dockerfile (#12021) * Add vllm awq loading logic (#11987) * [ADD] Add vllm awq loading logic * [FIX] fix the module.linear_method path * [FIX] fix quant_config path error * Enable Qwen padding mlp to 256 to support batch_forward (#12030) * Enable padding mlp * padding to 256 * update style * Install 27191 runtime in 0.5.4 docker image (#12040) * fix rebase error * fix rebase error * vLLM: format for 0.5.4 rebase (#12043) * format * Update model_convert.py * Fix serving docker related modifications (#12046) * Fix undesired modifications (#12048) * fix * Refine offline_inference arguments --------- Co-authored-by: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com> Co-authored-by: Jun Wang Co-authored-by: Wang, Jian4 <61138589+hzjane@users.noreply.github.com> Co-authored-by: liu-shaojun Co-authored-by: Shaojun Liu <61072813+liu-shaojun@users.noreply.github.com> --- docker/llm/serving/xpu/docker/Dockerfile | 120 +++-- .../xpu/docker/benchmark_vllm_throughput.py | 59 ++- .../serving/xpu/docker/start-vllm-service.sh | 3 +- .../xpu/docker/vllm_offline_inference.py | 6 +- python/llm/dev/test/lint-python | 2 +- .../llm/src/ipex_llm/transformers/convert.py | 15 +- .../ipex_llm/transformers/low_bit_linear.py | 6 +- .../src/ipex_llm/vllm/xpu/engine/engine.py | 78 +-- .../vllm/xpu/entrypoints/openai/api_server.py | 449 ++++++++++------- .../vllm/xpu/entrypoints/openai/cli_args.py | 163 ++++++ .../vllm/xpu/entrypoints/openai/rpc/server.py | 221 +++++++++ .../vllm/xpu/ipex_llm_gpu_executor.py | 466 ------------------ .../src/ipex_llm/vllm/xpu/ipex_llm_wrapper.py | 24 + .../src/ipex_llm/vllm/xpu/model_convert.py | 300 +++-------- 14 files changed, 903 insertions(+), 1009 deletions(-) create mode 100644 python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py create mode 100644 python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/rpc/server.py delete mode 100644 python/llm/src/ipex_llm/vllm/xpu/ipex_llm_gpu_executor.py create mode 100644 python/llm/src/ipex_llm/vllm/xpu/ipex_llm_wrapper.py diff --git a/docker/llm/serving/xpu/docker/Dockerfile b/docker/llm/serving/xpu/docker/Dockerfile index 7f2f41bd..bdaeac50 100644 --- a/docker/llm/serving/xpu/docker/Dockerfile +++ b/docker/llm/serving/xpu/docker/Dockerfile @@ -1,60 +1,90 @@ -FROM intelanalytics/ipex-llm-serving-xpu:latest as build +FROM intel/oneapi-basekit:2024.1.1-devel-ubuntu22.04 ARG http_proxy ARG https_proxy -ADD ./oneccl-binding.patch 
/tmp/oneccl-binding.patch +ENV TZ=Asia/Shanghai +ENV PYTHONUNBUFFERED=1 -RUN cd /tmp/ && \ - pip install --upgrade setuptools wheel twine && \ - pip install "setuptools<70.0.0" && \ - git clone https://github.com/intel/torch-ccl -b v2.1.100+xpu && \ - cd torch-ccl && \ - patch -p1 < /tmp/oneccl-binding.patch && \ - git submodule sync && \ - git submodule update --init --recursive && \ - COMPUTE_BACKEND=dpcpp python setup.py sdist bdist_wheel && \ - mv /tmp/torch-ccl/dist/oneccl_bind_pt-2.1.100+xpu-cp311-cp311-linux_x86_64.whl /tmp/ - - -FROM intelanalytics/ipex-llm-xpu:2.2.0-SNAPSHOT - -ARG http_proxy -ARG https_proxy # Disable pip's cache behavior ARG PIP_NO_CACHE_DIR=false -COPY --from=build /tmp/oneccl_bind_pt-2.1.100+xpu-cp311-cp311-linux_x86_64.whl /tmp/ -ADD ./gradio_web_server.patch /tmp/gradio_web_server.patch +ADD ./gradio_web_server.patch /tmp/gradio_web_server.patch +ADD ./oneccl-binding.patch /tmp/oneccl-binding.patch -# Install Serving Dependencies -# Install ipex-llm[serving] only will update ipex_llm source code without updating -# bigdl-core-xe, which will lead to problems -RUN apt-get update && \ - apt-get install -y --no-install-recommends libfabric-dev wrk libaio-dev && \ - apt-get install -y intel-opencl-icd intel-level-zero-gpu=1.3.26241.33-647~22.04 level-zero level-zero-dev --allow-downgrades && \ - pip install --pre --upgrade ipex-llm[xpu,serving] && \ - pip install transformers==4.37.0 gradio==4.19.2 && \ - # Install vLLM-v2 dependencies - git clone -b sycl_xpu https://github.com/analytics-zoo/vllm.git /llm/vllm && \ - pip install -r /llm/vllm/requirements-xpu.txt && \ - pip install --no-deps xformers && \ - VLLM_BUILD_XPU_OPS=1 pip install --no-build-isolation -v -e /llm/vllm && \ - pip install outlines==0.0.34 --no-deps && \ - pip install interegular cloudpickle diskcache joblib lark nest-asyncio numba scipy && \ - # For Qwen series models support +RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \ + echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \ + chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \ + rm /etc/apt/sources.list.d/intel-graphics.list && \ + wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \ + echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ + chmod 644 /usr/share/keyrings/intel-graphics.gpg && \ + apt-get update && \ + apt-get install -y --no-install-recommends curl wget git libunwind8-dev vim less && \ + ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && \ + env DEBIAN_FRONTEND=noninteractive apt-get update && \ + # add-apt-repository requires gnupg, gpg-agent, software-properties-common + apt-get install -y --no-install-recommends gnupg gpg-agent software-properties-common && \ + # Add Python 3.11 PPA repository + add-apt-repository ppa:deadsnakes/ppa -y && \ + apt-get install -y --no-install-recommends python3.11 git curl wget && \ + rm /usr/bin/python3 && \ + ln -s /usr/bin/python3.11 /usr/bin/python3 && \ + ln -s /usr/bin/python3 /usr/bin/python && \ + apt-get install -y --no-install-recommends python3-pip python3.11-dev 
python3-wheel python3.11-distutils && \ + wget https://bootstrap.pypa.io/get-pip.py -O get-pip.py && \ + # Install FastChat from source requires PEP 660 support + python3 get-pip.py && \ + rm get-pip.py && \ + pip install --upgrade requests argparse urllib3 && \ + pip install --pre --upgrade ipex-llm[xpu,serving] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ && \ + pip install transformers==4.36.2 && \ pip install transformers_stream_generator einops tiktoken && \ - # For pipeline serving support - pip install mpi4py fastapi uvicorn openai && \ - # for gradio web UI - pip install gradio && \ - # Install internal oneccl && \ + pip install --upgrade colorama && \ + # Download all-in-one benchmark and examples + git clone https://github.com/intel-analytics/ipex-llm && \ + cp -r ./ipex-llm/python/llm/dev/benchmark/ ./benchmark && \ + cp -r ./ipex-llm/python/llm/example/GPU/HuggingFace/LLM ./examples && \ + # Install vllm dependencies + pip install --upgrade fastapi && \ + pip install --upgrade "uvicorn[standard]" && \ + # Download vLLM-Serving + cp -r ./ipex-llm/python/llm/example/GPU/vLLM-Serving/ ./vLLM-Serving && \ + rm -rf ./ipex-llm && \ + # Install torch-ccl cd /tmp/ && \ + pip install torch==2.1.0.post2 torchvision==0.16.0.post2 torchaudio==2.1.0.post2 intel-extension-for-pytorch==2.1.30.post0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ && \ + # Internal oneccl wget https://sourceforge.net/projects/oneccl-wks/files/oneccl_wks_installer_2024.0.0.2.sh && \ bash oneccl_wks_installer_2024.0.0.2.sh && \ - pip uninstall -y oneccl_bind_pt && \ - pip install /tmp/oneccl_bind_pt-2.1.100+xpu-cp311-cp311-linux_x86_64.whl && \ - rm /tmp/oneccl_bind_pt-2.1.100+xpu-cp311-cp311-linux_x86_64.whl && \ + git clone https://github.com/intel/torch-ccl -b v2.1.300+xpu && \ + cd torch-ccl && \ + patch -p1 < /tmp/oneccl-binding.patch && \ + USE_SYSTEM_ONECCL=ON COMPUTE_BACKEND=dpcpp python setup.py install && \ + apt-get update && \ + apt-get install -y --no-install-recommends libfabric-dev wrk libaio-dev numactl && \ + # apt-get install -y intel-opencl-icd intel-level-zero-gpu=1.3.26241.33-647~22.04 level-zero level-zero-dev --allow-downgrades && \ + mkdir -p /tmp/neo && \ + cd /tmp/neo && \ + wget https://github.com/intel/intel-graphics-compiler/releases/download/igc-1.0.15136.4/intel-igc-core_1.0.15136.4_amd64.deb && \ + wget https://github.com/intel/intel-graphics-compiler/releases/download/igc-1.0.15136.4/intel-igc-opencl_1.0.15136.4_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/23.35.27191.9/intel-level-zero-gpu-dbgsym_1.3.27191.9_amd64.ddeb && \ + wget https://github.com/intel/compute-runtime/releases/download/23.35.27191.9/intel-level-zero-gpu_1.3.27191.9_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/23.35.27191.9/intel-opencl-icd-dbgsym_23.35.27191.9_amd64.ddeb && \ + wget https://github.com/intel/compute-runtime/releases/download/23.35.27191.9/intel-opencl-icd_23.35.27191.9_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/23.35.27191.9/libigdgmm12_22.3.11.ci17747749_amd64.deb && \ + dpkg -i *.deb && \ + rm -rf /tmp/neo && \ + mkdir -p /llm && \ + cd /llm && \ + git clone -b 0.5.4 https://github.com/analytics-zoo/vllm.git /llm/vllm && \ + cd /llm/vllm && \ + pip install -r /llm/vllm/requirements-xpu.txt && \ + VLLM_TARGET_DEVICE=xpu python setup.py install && \ + pip install mpi4py fastapi uvicorn openai && \ + pip install gradio 
&& \ + # patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch && \ + pip install ray && \ patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch COPY ./vllm_online_benchmark.py /llm/ @@ -66,5 +96,7 @@ COPY ./start-fastchat-service.sh /llm/ COPY ./start-pp_serving-service.sh /llm/ COPY ./start-lightweight_serving-service.sh /llm/ +ENV LD_LIBRARY_PATH /usr/local/lib/python3.11/dist-packages/intel_extension_for_pytorch/lib/:/opt/intel/oneapi/tbb/2021.12/env/../lib/intel64/gcc4.8:/opt/intel/oneapi/mpi/2021.12/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/2021.12/lib:/opt/intel/oneapi/mkl/2024.1/lib:/opt/intel/oneapi/ippcp/2021.11/lib/:/opt/intel/oneapi/ipp/2021.11/lib:/opt/intel/oneapi/dpl/2022.5/lib:/opt/intel/oneapi/dnnl/2024.1/lib:/opt/intel/oneapi/debugger/2024.1/opt/debugger/lib:/opt/intel/oneapi/dal/2024.2/lib:/opt/intel/oneapi/compiler/2024.1/opt/oclfpga/host/linux64/lib:/opt/intel/oneapi/compiler/2024.1/opt/compiler/lib:/opt/intel/oneapi/compiler/2024.1/lib:/opt/intel/oneapi/ccl/2021.12/lib/ +ENV BIGDL_LLM_SDP_IGNORE_MASK 0 WORKDIR /llm/ diff --git a/docker/llm/serving/xpu/docker/benchmark_vllm_throughput.py b/docker/llm/serving/xpu/docker/benchmark_vllm_throughput.py index 94e04584..28e94da1 100644 --- a/docker/llm/serving/xpu/docker/benchmark_vllm_throughput.py +++ b/docker/llm/serving/xpu/docker/benchmark_vllm_throughput.py @@ -103,41 +103,38 @@ def run_vllm( warm_prompt = "hi " * (1024 - 1) warm_requests = [(warm_prompt, 1024, 1024) for _ in range(8)] - for prompt, _, output_len in warm_requests: - sampling_params = SamplingParams( - n=n, - temperature=0.0 if use_beam_search else 1.0, - top_p=1.0, - use_beam_search=use_beam_search, - ignore_eos=True, - max_tokens=output_len, - ) - llm._add_request( - prompt=prompt, - prompt_token_ids=None, - sampling_params=sampling_params, - ) - llm._run_engine(use_tqdm=True) + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] + for prompt, _, output_len in warm_requests: + prompts.append(prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + )) + llm.generate(prompts, sampling_params, use_tqdm=True) + + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] for prompt, _, output_len in requests: - sampling_params = SamplingParams( - n=n, - temperature=0.0 if use_beam_search else 1.0, - top_p=1.0, - use_beam_search=use_beam_search, - ignore_eos=True, - max_tokens=output_len, - ) - # FIXME(woosuk): Do not use internal method. - llm._add_request( - prompt=prompt, - prompt_token_ids=None, - sampling_params=sampling_params, - ) + prompts.append(prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + )) start = time.perf_counter() - # FIXME(woosuk): Do not use internal method. 
- llm._run_engine(use_tqdm=True) + llm.generate(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() return end - start diff --git a/docker/llm/serving/xpu/docker/start-vllm-service.sh b/docker/llm/serving/xpu/docker/start-vllm-service.sh index c0d0f112..a7860a91 100644 --- a/docker/llm/serving/xpu/docker/start-vllm-service.sh +++ b/docker/llm/serving/xpu/docker/start-vllm-service.sh @@ -3,6 +3,7 @@ model="YOUR_MODEL_PATH" served_model_name="YOUR_MODEL_NAME" source /opt/intel/1ccl-wks/setvars.sh +export BIGDL_LLM_SDP_IGNORE_MASK=0 python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \ --served-model-name $served_model_name \ @@ -17,4 +18,4 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \ --max-model-len 4096 \ --max-num-batched-tokens 10240 \ --max-num-seqs 12 \ - --tensor-parallel-size 1 \ No newline at end of file + --tensor-parallel-size 1 diff --git a/docker/llm/serving/xpu/docker/vllm_offline_inference.py b/docker/llm/serving/xpu/docker/vllm_offline_inference.py index 15ecf4f9..4f09483f 100644 --- a/docker/llm/serving/xpu/docker/vllm_offline_inference.py +++ b/docker/llm/serving/xpu/docker/vllm_offline_inference.py @@ -49,8 +49,10 @@ llm = LLM(model="YOUR_MODEL", device="xpu", dtype="float16", enforce_eager=True, - load_in_low_bit="sym_int4", - tensor_parallel_size=1) + load_in_low_bit="fp8", + tensor_parallel_size=1, + max_model_len=2000, + max_num_batched_tokens=2000) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/python/llm/dev/test/lint-python b/python/llm/dev/test/lint-python index 08c49b00..f8345e04 100755 --- a/python/llm/dev/test/lint-python +++ b/python/llm/dev/test/lint-python @@ -21,7 +21,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" PYTHON_ROOT_DIR="$SCRIPT_DIR/.." 
echo $PYTHON_ROOT_DIR PATHS_TO_CHECK="$SCRIPT_DIR/../../src" -PATTERNS_TO_EXCLUDE="__init__.py,log4Error.py,$SCRIPT_DIR/../../src/ipex_llm/langchain/*,$SCRIPT_DIR/../../src/ipex_llm/transformers/gguf/models/model_implement/yuan2/*,benchmark_util_4_29.py,benchmark_util_4_42.py,benchmark_util_4_43.py,tgi_api_server.py" +PATTERNS_TO_EXCLUDE="__init__.py,log4Error.py,$SCRIPT_DIR/../../src/ipex_llm/langchain/*,$SCRIPT_DIR/../../src/ipex_llm/transformers/gguf/models/model_implement/yuan2/*,benchmark_util_4_29.py,benchmark_util_4_42.py,benchmark_util_4_43.py,tgi_api_server.py,api_server.py" PEP8_REPORT_PATH="$PYTHON_ROOT_DIR/test/pep8-report.txt" PYLINT_REPORT_PATH="$PYTHON_ROOT_DIR/test/pylint-report.txt" PYLINT_INSTALL_INFO="$PYTHON_ROOT_DIR/test/pylint-info.txt" diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index d3fd44c3..159d3171 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -160,7 +160,7 @@ def is_linear_module(module): if is_module_in_classes(module, VLLM_LINEAR_LIST): if 'xpu' in _VLLM_VERSION: # For vllm xpu - from vllm.model_executor.parallel_utils.parallel_state import ( + from vllm.distributed.parallel_state import ( get_tensor_model_parallel_group, get_tensor_model_parallel_world_size ) @@ -183,8 +183,8 @@ def is_linear_module(module): mp_group = None # Check for attribute qweight if (not _USE_VLLM_AWQ - and hasattr(module.linear_method, "quant_config") - and module.linear_method.quant_config.get_name() == "awq"): + and hasattr(module.quant_method, "quant_config") + and module.quant_method.quant_config.get_name() == "awq"): _USE_VLLM_AWQ = True invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not" " support linear layers with skip_bias_add argument") @@ -231,7 +231,6 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype, from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from ipex_llm.transformers.low_bit_linear import LowBitLinear, \ FP16Linear, BF16Linear, vLLMLowBitLinear, vLLMFP16Linear, vLLMBF16Linear - # Currently, vLLM does not support optimize_lm_head = True optimize_lm_head = False if isinstance(module, ParallelLMHead): if qtype == ggml_tensor_qtype["fp16"]: @@ -301,7 +300,7 @@ def convert_vllm_awq(module): dtype=torch.int32) * 4).unsqueeze(0) # vLLM only supports load 4-bits model, so this has been checked bits = 4 - group_size = module.linear_method.quant_config.group_size + group_size = module.quant_method.quant_config.group_size zeros = torch.bitwise_right_shift( torch.unsqueeze(module.qzeros, 2).expand(-1, -1, 32 // bits), @@ -466,6 +465,12 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, if any(key in full_module_name for key in modules_to_not_convert): continue + if is_linear and getattr(model_config, "model_type", None) == "chatglm" and \ + name == "lm_head": + # Now we re-reference it to output_layer + model._modules[name] = model._modules["transformer"]._modules["output_layer"] + continue + if is_linear and not isinstance(module, LowBitLinear): in_features, out_features, mp_group = linear_args optimize_lm_head = ( diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py index fddcb7c9..d30126a6 100644 --- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py +++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py @@ -802,7 +802,7 @@ class 
LowBitLinear(nn.Linear): result = result.view(new_shape) if self.mp_group is not None: if get_use_vllm(): - torch.distributed.all_reduce(result, group=self.mp_group) + result = self.mp_group.all_reduce(result) elif is_deepspeed_available(): from deepspeed import comm as dist dist.inference_all_reduce(result, group=self.mp_group) @@ -889,7 +889,7 @@ class FP16Linear(nn.Linear): result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias) if self.mp_group is not None: if get_use_vllm(): - torch.distributed.all_reduce(result, group=self.mp_group) + result = self.mp_group.all_reduce(result) elif is_deepspeed_available(): from deepspeed import comm as dist dist.inference_all_reduce(result, group=self.mp_group) @@ -926,7 +926,7 @@ class FP16Linear(nn.Linear): result = result.view(new_shape) if self.mp_group is not None: if get_use_vllm(): - torch.distributed.all_reduce(result, group=self.mp_group) + result = self.mp_group.all_reduce(result) elif is_deepspeed_available(): from deepspeed import comm as dist dist.inference_all_reduce(result, group=self.mp_group) diff --git a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py index c640cb00..0a3a8741 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py +++ b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py @@ -13,16 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -from typing import List, Optional, Union +from typing import Dict, Optional from vllm.engine.llm_engine import LLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.engine.ray_utils import initialize_ray_cluster from vllm.entrypoints.llm import LLM from vllm.utils import Counter from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert -from ipex_llm.utils.common import invalidInputError +from vllm.usage.usage_lib import UsageContext +from vllm.engine.metrics import StatLoggerBase class IPEXLLMAsyncLLMEngine(AsyncLLMEngine): @@ -34,35 +33,14 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine): cls, engine_args: AsyncEngineArgs, start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, load_in_low_bit: str = "sym_int4", - # ipex_llm_optimize_mode: str = 'NATIVE', + stat_loggers: Optional[Dict[str, StatLoggerBase]]=None, ) -> "AsyncLLMEngine": """Creates an async LLM engine from the engine arguments.""" - # Enable ipex-llm optimizations - engine_configs = engine_args.create_engine_configs() - + # Create the engine configs. _ipex_llm_convert(load_in_low_bit) - parallel_config = engine_configs[2] - if parallel_config.worker_use_ray or engine_args.engine_use_ray: - initialize_ray_cluster(parallel_config) - # from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync - from ipex_llm.vllm.xpu.ipex_llm_gpu_executor import get_gpu_executor_class_async - executor_class = get_gpu_executor_class_async(load_in_low_bit) - else: - invalidInputError(parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.")) - from vllm.executor.gpu_executor import GPUExecutorAsync - executor_class = GPUExecutorAsync - # Create the async LLM engine. 
- engine = cls(parallel_config.worker_use_ray, - engine_args.engine_use_ray, - *engine_configs, - executor_class, - log_requests=not engine_args.disable_log_requests, - log_stats=not engine_args.disable_log_stats, - max_log_len=engine_args.max_log_len, - start_engine_loop=start_engine_loop) - return engine + return super().from_engine_args(engine_args, start_engine_loop, usage_context, stat_loggers) class IPEXLLMClass(LLM): @@ -71,6 +49,7 @@ class IPEXLLMClass(LLM): model: str, tokenizer: Optional[str] = None, tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, trust_remote_code: bool = False, tensor_parallel_size: int = 1, dtype: str = "auto", @@ -80,18 +59,26 @@ class IPEXLLMClass(LLM): seed: int = 0, gpu_memory_utilization: float = 0.9, swap_space: int = 4, + cpu_offload_gb: float = 0, enforce_eager: bool = False, - max_context_len_to_capture: int = 8192, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, load_in_low_bit: str = "sym_int4", **kwargs, ) -> None: if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True + removed_vision_keys = ("image_token_id", "image_feature_size", + "image_input_shape", "image_input_type") + if any(k in kwargs for k in removed_vision_keys): + raise TypeError( # noqa + "There is no need to pass vision-related arguments anymore.") engine_args = EngineArgs( model=model, tokenizer=tokenizer, tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, trust_remote_code=trust_remote_code, tensor_parallel_size=tensor_parallel_size, dtype=dtype, @@ -101,13 +88,16 @@ class IPEXLLMClass(LLM): seed=seed, gpu_memory_utilization=gpu_memory_utilization, swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) - self.llm_engine = IPEXLLMLLMEngine.from_engine_args(engine_args, - load_in_low_bit=load_in_low_bit) + self.llm_engine = IPEXLLMLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.LLM_CLASS, + load_in_low_bit=load_in_low_bit) self.request_counter = Counter() @@ -119,29 +109,11 @@ class IPEXLLMLLMEngine(LLMEngine): def from_engine_args( cls, engine_args: EngineArgs, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]]=None, load_in_low_bit: str = "sym_int4", - # ipex_llm_optimize_mode: str = 'NATIVE', ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" # Create the engine configs. - engine_configs = engine_args.create_engine_configs() _ipex_llm_convert(load_in_low_bit) - parallel_config = engine_configs[2] - - # Initialize the cluster and specify the executor class. - if parallel_config.worker_use_ray: - initialize_ray_cluster(parallel_config) - # from vllm.executor.ray_gpu_executor import RayGPUExecutor - from ipex_llm.vllm.xpu.ipex_llm_gpu_executor import get_gpu_executor_class - executor_class = get_gpu_executor_class(load_in_low_bit) - else: - invalidInputError(parallel_config.world_size == 1, - "Ray is required if parallel_config.world_size > 1.") - from vllm.executor.gpu_executor import GPUExecutor - executor_class = GPUExecutor - - # Create the LLM engine. 
- engine = cls(*engine_configs, - executor_class=executor_class, - log_stats=not engine_args.disable_log_stats) - return engine + return super().from_engine_args(engine_args, usage_context, stat_loggers) diff --git a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py index 2b6a0614..b6ea51c1 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py +++ b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/api_server.py @@ -1,199 +1,191 @@ -import argparse import asyncio -import json -from contextlib import asynccontextmanager -import os import importlib import inspect -import ssl - -from prometheus_client import make_asgi_app -import fastapi -import uvicorn +import re +from argparse import Namespace +from contextlib import asynccontextmanager from http import HTTPStatus -from fastapi import Request +from multiprocessing import Process +from typing import AsyncIterator, Set + +from fastapi import APIRouter, FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, StreamingResponse, Response +from fastapi.responses import JSONResponse, Response, StreamingResponse +from prometheus_client import make_asgi_app +from starlette.routing import Mount -import vllm +import vllm.envs as envs +from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.entrypoints.openai.protocol import (CompletionRequest, - ChatCompletionRequest, - ErrorResponse) -from vllm.logger import init_logger +from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.logger import RequestLogger +from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + DetokenizeRequest, + DetokenizeResponse, + EmbeddingRequest, ErrorResponse, + TokenizeRequest, + TokenizeResponse) +from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient +from ipex_llm.vllm.xpu.entrypoints.openai.rpc.server import run_rpc_server +# yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion -from vllm.entrypoints.openai.serving_engine import LoRA -from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine -from ipex_llm.utils.common import invalidInputError +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_tokenization import ( + OpenAIServingTokenization) +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, get_open_port +from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds -openai_serving_chat: OpenAIServingChat = None -openai_serving_completion: OpenAIServingCompletion = None -logger = init_logger(__name__) +async_engine_client: AsyncEngineClient +engine_args: AsyncEngineArgs +openai_serving_chat: OpenAIServingChat +openai_serving_completion: OpenAIServingCompletion +openai_serving_embedding: OpenAIServingEmbedding +openai_serving_tokenization: OpenAIServingTokenization + +logger = 
init_logger('vllm.entrypoints.openai.api_server') + +_running_tasks: Set[asyncio.Task] = set() + + +def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool: + return ModelConfig(model=model_name, + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=trust_remote_code, + seed=0, + dtype="float16").embedding_mode @asynccontextmanager -async def lifespan(app: fastapi.FastAPI): +async def lifespan(app: FastAPI): async def _force_log(): while True: await asyncio.sleep(10) - await engine.do_log_stats() + await async_engine_client.do_log_stats() if not engine_args.disable_log_stats: - asyncio.create_task(_force_log()) + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) yield -app = fastapi.FastAPI(lifespan=lifespan) +@asynccontextmanager +async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: + # Context manager to handle async_engine_client lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + global engine_args + engine_args = AsyncEngineArgs.from_cli_args(args) + + # Backend itself still global for the silly lil' health handler + global async_engine_client + + # If manually triggered or embedding model, use AsyncLLMEngine in process. + # TODO: support embedding model via RPC. + if (model_is_embedding(args.model, args.trust_remote_code) + or args.disable_frontend_multiprocessing): + async_engine_client = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER, + load_in_low_bit=args.load_in_low_bit) + yield async_engine_client + return + + # Otherwise, use the multiprocessing AsyncLLMEngine. + else: + # Start RPCServer in separate process (holds the AsyncLLMEngine). + port = get_open_port(envs.VLLM_RPC_PORT) + load_in_low_bit = args.load_in_low_bit + rpc_server_process = Process(target=run_rpc_server, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + port, load_in_low_bit)) + rpc_server_process.start() + + # Build RPCClient, which conforms to AsyncEngineClient Protocol. 
+ async_engine_client = AsyncEngineRPCClient(port) + await async_engine_client.setup() + + try: + yield async_engine_client + finally: + # Ensure rpc server process was terminated + rpc_server_process.terminate() + + # Close all open connections to the backend + async_engine_client.close() + + # Wait for server process to join + rpc_server_process.join() -class LoRAParserAction(argparse.Action): - - def __call__(self, parser, namespace, values, option_string=None): - lora_list = [] - for item in values: - name, path = item.split('=') - lora_list.append(LoRA(name, path)) - setattr(namespace, self.dest, lora_list) +router = APIRouter() -def parse_args(): - parser = argparse.ArgumentParser( - description="vLLM OpenAI-Compatible RESTful API server.") - parser.add_argument("--host", type=str, default=None, help="host name") - parser.add_argument("--port", type=int, default=8000, help="port number") - parser.add_argument( - "--uvicorn-log-level", - type=str, - default="info", - choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'], - help="log level for uvicorn") - parser.add_argument("--allow-credentials", - action="store_true", - help="allow credentials") - parser.add_argument("--allowed-origins", - type=json.loads, - default=["*"], - help="allowed origins") - parser.add_argument("--allowed-methods", - type=json.loads, - default=["*"], - help="allowed methods") - parser.add_argument("--allowed-headers", - type=json.loads, - default=["*"], - help="allowed headers") - parser.add_argument("--api-key", - type=str, - default=None, - help="If provided, the server will require this key " - "to be presented in the header.") - parser.add_argument("--served-model-name", - type=str, - default=None, - help="The model name used in the API. If not " - "specified, the model name will be the same as " - "the huggingface name.") - parser.add_argument( - "--lora-modules", - type=str, - default=None, - nargs='+', - action=LoRAParserAction, - help="LoRA module configurations in the format name=path. " - "Multiple modules can be specified.") - parser.add_argument("--chat-template", - type=str, - default=None, - help="The file path to the chat template, " - "or the template in single-line form " - "for the specified model") - parser.add_argument("--response-role", - type=str, - default="assistant", - help="The role name to return if " - "`request.add_generation_prompt=true`.") - parser.add_argument("--ssl-keyfile", - type=str, - default=None, - help="The file path to the SSL key file") - parser.add_argument("--ssl-certfile", - type=str, - default=None, - help="The file path to the SSL cert file") - parser.add_argument("--ssl-ca-certs", - type=str, - default=None, - help="The CA certificates file") - parser.add_argument( - "--ssl-cert-reqs", - type=int, - default=int(ssl.CERT_NONE), - help="Whether client certificate is required (see stdlib ssl module's)" - ) - parser.add_argument( - "--root-path", - type=str, - default=None, - help="FastAPI root_path when app is behind a path based routing proxy") - parser.add_argument( - "--middleware", - type=str, - action="append", - default=[], - help="Additional ASGI middleware to apply to the app. " - "We accept multiple --middleware arguments. " - "The value should be an import path. " - "If a function is provided, vLLM will add it to the server " - "using @app.middleware('http'). " - "If a class is provided, vLLM will add it to the server " - "using app.add_middleware(). 
") - parser.add_argument( - "--load-in-low-bit", - type=str, - default="sym_int4", - help="Low-bit quantization for IPEX-LLM models") - - parser = AsyncEngineArgs.add_cli_args(parser) - return parser.parse_args() +def mount_metrics(app: FastAPI): + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app()) + # Workaround for 307 Redirect for /metrics + metrics_route.path_regex = re.compile('^/metrics(?P.*)$') + app.routes.append(metrics_route) -# Add prometheus asgi middleware to route /metrics requests -metrics_app = make_asgi_app() -app.mount("/metrics", metrics_app) - - -@app.exception_handler(RequestValidationError) -async def validation_exception_handler(_, exc): - err = openai_serving_chat.create_error_response(message=str(exc)) - return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) - - -@app.get("/health") +@router.get("/health") async def health() -> Response: """Health check.""" - await openai_serving_chat.engine.check_health() + await async_engine_client.check_health() return Response(status_code=200) -@app.get("/v1/models") +@router.post("/tokenize") +async def tokenize(request: TokenizeRequest): + generator = await openai_serving_tokenization.create_tokenize(request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + else: + assert isinstance(generator, TokenizeResponse) + return JSONResponse(content=generator.model_dump()) + + +@router.post("/detokenize") +async def detokenize(request: DetokenizeRequest): + generator = await openai_serving_tokenization.create_detokenize(request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + else: + assert isinstance(generator, DetokenizeResponse) + return JSONResponse(content=generator.model_dump()) + + +@router.get("/v1/models") async def show_available_models(): - models = await openai_serving_chat.show_available_models() + models = await openai_serving_completion.show_available_models() return JSONResponse(content=models.model_dump()) -@app.get("/version") +@router.get("/version") async def show_version(): - ver = {"version": vllm.__version__} + ver = {"version": VLLM_VERSION} return JSONResponse(content=ver) -@app.post("/v1/chat/completions") +@router.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): generator = await openai_serving_chat.create_chat_completion( @@ -205,10 +197,11 @@ async def create_chat_completion(request: ChatCompletionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") else: + assert isinstance(generator, ChatCompletionResponse) return JSONResponse(content=generator.model_dump()) -@app.post("/v1/completions") +@router.post("/v1/completions") async def create_completion(request: CompletionRequest, raw_request: Request): generator = await openai_serving_completion.create_completion( request, raw_request) @@ -222,8 +215,23 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) -if __name__ == "__main__": - args = parse_args() +@router.post("/v1/embeddings") +async def create_embedding(request: EmbeddingRequest, raw_request: Request): + generator = await openai_serving_embedding.create_embedding( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + 
status_code=generator.code) + else: + return JSONResponse(content=generator.model_dump()) + + +def build_app(args: Namespace) -> FastAPI: + app = FastAPI(lifespan=lifespan) + app.include_router(router) + app.root_path = args.root_path + + mount_metrics(app) app.add_middleware( CORSMiddleware, @@ -233,11 +241,20 @@ if __name__ == "__main__": allow_headers=args.allowed_headers, ) - token = os.environ.get("VLLM_API_KEY") or args.api_key - if token: + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(_, exc): + err = openai_serving_chat.create_error_response(message=str(exc)) + return JSONResponse(err.model_dump(), + status_code=HTTPStatus.BAD_REQUEST) + + if token := envs.VLLM_API_KEY or args.api_key: + @app.middleware("http") async def authentication(request: Request, call_next): - if not request.url.path.startswith("/v1"): + root_path = "" if args.root_path is None else args.root_path + if request.method == "OPTIONS": + return await call_next(request) + if not request.url.path.startswith(f"{root_path}/v1"): return await call_next(request) if request.headers.get("Authorization") != "Bearer " + token: return JSONResponse(content={"error": "Unauthorized"}, @@ -252,34 +269,104 @@ if __name__ == "__main__": elif inspect.iscoroutinefunction(imported): app.middleware("http")(imported) else: - invalidInputError(False, (f"Invalid middleware {middleware}. " - f"Must be a function or a class.")) + raise ValueError(f"Invalid middleware {middleware}. " + f"Must be a function or a class.") - logger.info(f"vLLM API server version {vllm.__version__}") - logger.info(f"args: {args}") + return app + + +async def init_app( + async_engine_client: AsyncEngineClient, + args: Namespace, +) -> FastAPI: + app = build_app(args) if args.served_model_name is not None: - served_model = args.served_model_name + served_model_names = args.served_model_name else: - served_model = args.model + served_model_names = [args.model] - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = IPEXLLMAsyncLLMEngine.from_engine_args(engine_args, - load_in_low_bit=args.load_in_low_bit) - openai_serving_chat = OpenAIServingChat(engine, served_model, - args.response_role, - args.lora_modules, - args.chat_template) + model_config = await async_engine_client.get_model_config() + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + global openai_serving_chat + global openai_serving_completion + global openai_serving_embedding + global openai_serving_tokenization + + openai_serving_chat = OpenAIServingChat( + async_engine_client, + model_config, + served_model_names, + args.response_role, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, + chat_template=args.chat_template, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + ) openai_serving_completion = OpenAIServingCompletion( - engine, served_model, args.lora_modules) - + async_engine_client, + model_config, + served_model_names, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + ) + openai_serving_embedding = OpenAIServingEmbedding( + async_engine_client, + model_config, + served_model_names, + request_logger=request_logger, + ) + openai_serving_tokenization = OpenAIServingTokenization( + async_engine_client, + model_config, + served_model_names, + lora_modules=args.lora_modules, + 
request_logger=request_logger, + chat_template=args.chat_template, + ) app.root_path = args.root_path - uvicorn.run(app, - host=args.host, - port=args.port, - log_level=args.uvicorn_log_level, - timeout_keep_alive=TIMEOUT_KEEP_ALIVE, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs) + + return app + + +async def run_server(args, **uvicorn_kwargs) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + async with build_async_engine_client(args) as async_engine_client: + app = await init_app(async_engine_client, args) + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + await shutdown_task + + +if __name__ == "__main__": + # NOTE(simon): + # This section should be in sync with vllm/scripts.py for CLI entrypoints. + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server.") + parser = make_arg_parser(parser) + args = parser.parse_args() + + asyncio.run(run_server(args)) diff --git a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py new file mode 100644 index 00000000..70af2a03 --- /dev/null +++ b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py @@ -0,0 +1,163 @@ +""" +This file contains the command line arguments for the vLLM's +OpenAI-compatible server. It is kept in a separate file for documentation +purposes. 
+""" + +import argparse +import json +import ssl + +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + PromptAdapterPath) +from vllm.utils import FlexibleArgumentParser + + +class LoRAParserAction(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + lora_list = [] + for item in values: + name, path = item.split('=') + lora_list.append(LoRAModulePath(name, path)) + setattr(namespace, self.dest, lora_list) + + +class PromptAdapterParserAction(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + adapter_list = [] + for item in values: + name, path = item.split('=') + adapter_list.append(PromptAdapterPath(name, path)) + setattr(namespace, self.dest, adapter_list) + + +def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument("--host", + type=nullable_str, + default=None, + help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument( + "--uvicorn-log-level", + type=str, + default="info", + choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'], + help="log level for uvicorn") + parser.add_argument("--allow-credentials", + action="store_true", + help="allow credentials") + parser.add_argument("--allowed-origins", + type=json.loads, + default=["*"], + help="allowed origins") + parser.add_argument("--allowed-methods", + type=json.loads, + default=["*"], + help="allowed methods") + parser.add_argument("--allowed-headers", + type=json.loads, + default=["*"], + help="allowed headers") + parser.add_argument("--api-key", + type=nullable_str, + default=None, + help="If provided, the server will require this key " + "to be presented in the header.") + parser.add_argument( + "--lora-modules", + type=nullable_str, + default=None, + nargs='+', + action=LoRAParserAction, + help="LoRA module configurations in the format name=path. " + "Multiple modules can be specified.") + parser.add_argument( + "--prompt-adapters", + type=nullable_str, + default=None, + nargs='+', + action=PromptAdapterParserAction, + help="Prompt adapter configurations in the format name=path. " + "Multiple adapters can be specified.") + parser.add_argument("--chat-template", + type=nullable_str, + default=None, + help="The file path to the chat template, " + "or the template in single-line form " + "for the specified model") + parser.add_argument("--response-role", + type=nullable_str, + default="assistant", + help="The role name to return if " + "`request.add_generation_prompt=true`.") + parser.add_argument("--ssl-keyfile", + type=nullable_str, + default=None, + help="The file path to the SSL key file") + parser.add_argument("--ssl-certfile", + type=nullable_str, + default=None, + help="The file path to the SSL cert file") + parser.add_argument("--ssl-ca-certs", + type=nullable_str, + default=None, + help="The CA certificates file") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)" + ) + parser.add_argument( + "--root-path", + type=nullable_str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument( + "--middleware", + type=nullable_str, + action="append", + default=[], + help="Additional ASGI middleware to apply to the app. " + "We accept multiple --middleware arguments. 
" + "The value should be an import path. " + "If a function is provided, vLLM will add it to the server " + "using @app.middleware('http'). " + "If a class is provided, vLLM will add it to the server " + "using app.add_middleware(). ") + parser.add_argument( + "--return-tokens-as-token-ids", + action="store_true", + help="When --max-logprobs is specified, represents single tokens as " + "strings of the form 'token_id:{token_id}' so that tokens that " + "are not JSON-encodable can be identified.") + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + help="If specified, will run the OpenAI frontend server in the same " + "process as the model serving engine.") + parser.add_argument( + "--load-in-low-bit", + type=str, + default="sym_int4", + help="Low-bit quantization for IPEX-LLM models") + + parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + '\n\nDefault: Unlimited') + + return parser + + +def create_parser_for_docs() -> FlexibleArgumentParser: + parser_for_docs = FlexibleArgumentParser( + prog="-m vllm.entrypoints.openai.api_server") + return make_arg_parser(parser_for_docs) diff --git a/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/rpc/server.py b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/rpc/server.py new file mode 100644 index 00000000..f9c7778c --- /dev/null +++ b/python/llm/src/ipex_llm/vllm/xpu/entrypoints/openai/rpc/server.py @@ -0,0 +1,221 @@ +import asyncio +import signal +from typing import Any, Coroutine + +import cloudpickle +import zmq +import zmq.asyncio +from typing_extensions import Never + +from vllm import AsyncEngineArgs +from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine + +from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCUtilityRequest) +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext + +logger = init_logger(__name__) + + +class AsyncEngineRPCServer: + + def __init__(self, async_engine_args: AsyncEngineArgs, + usage_context: UsageContext, port: int, load_in_low_bit: str): + # Initialize engine first. + self.engine = AsyncLLMEngine.from_engine_args(async_engine_args, + usage_context=usage_context, + load_in_low_bit=load_in_low_bit) + + # Initialize context. + self.context = zmq.asyncio.Context() + + # Init socket for readiness state. 
+ self.socket = self.context.socket(zmq.constants.ROUTER) + # Note numeric form of localhost should be used for zmq bind(), + # see https://stackoverflow.com/a/8958414 + self.socket.bind(f"tcp://127.0.0.1:{port}") + + def cleanup(self): + """Cleanup all resources.""" + self.socket.close() + self.context.destroy() + + async def get_model_config(self, identity): + """Send the ModelConfig""" + model_config = await self.engine.get_model_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(model_config)]) + + async def get_decoding_config(self, identity): + """Send the DecodingConfig""" + decoding_config = await self.engine.get_decoding_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(decoding_config)]) + + async def get_lora_config(self, identity): + lora_config = await self.engine.get_lora_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(lora_config)]) + + async def get_scheduler_config(self, identity): + """Send the SchedulerConfig""" + parallel_config = await self.engine.get_scheduler_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(parallel_config)]) + + async def get_parallel_config(self, identity): + """Send the ParallelConfig""" + parallel_config = await self.engine.get_parallel_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(parallel_config)]) + + async def is_tracing_enabled(self, identity): + """Send the is_tracing_enabled flag""" + tracing_flag = await self.engine.is_tracing_enabled() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(tracing_flag)]) + + async def do_log_stats(self, identity): + """Log stats and confirm success.""" + await self.engine.do_log_stats() + + await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + + async def is_server_ready(self, identity): + """Notify the client that we are ready.""" + await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + + async def abort(self, identity, request: RPCAbortRequest): + """Abort request and notify the client of success.""" + # Abort the request in the llm engine. + await self.engine.abort(request.request_id) + + # Send confirmation to the client. 
+ await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + + async def generate(self, identity, generate_request: RPCGenerateRequest): + try: + results_generator = self.engine.generate( + generate_request.inputs, + sampling_params=generate_request.sampling_params, + request_id=generate_request.request_id, + lora_request=generate_request.lora_request, + trace_headers=generate_request.trace_headers, + prompt_adapter_request=generate_request.prompt_adapter_request) + + async for request_output in results_generator: + await self.socket.send_multipart( + [identity, cloudpickle.dumps(request_output)]) + + except Exception as e: + # Notify client of all failures + await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + + async def check_health(self, identity): + try: + await self.engine.check_health() + await self.socket.send_multipart( + [identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)]) + except Exception as e: + await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + + def _make_handler_coro(self, identity, + message) -> Coroutine[Any, Any, Never]: + """Route the zmq message to the handler coroutine.""" + + request = cloudpickle.loads(message) + + if isinstance(request, RPCGenerateRequest): + return self.generate(identity, request) + + elif isinstance(request, RPCAbortRequest): + return self.abort(identity, request) + + elif isinstance(request, RPCUtilityRequest): + if request == RPCUtilityRequest.GET_MODEL_CONFIG: + return self.get_model_config(identity) + elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: + return self.get_parallel_config(identity) + elif request == RPCUtilityRequest.GET_DECODING_CONFIG: + return self.get_decoding_config(identity) + elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: + return self.get_scheduler_config(identity) + elif request == RPCUtilityRequest.GET_LORA_CONFIG: + return self.get_lora_config(identity) + elif request == RPCUtilityRequest.DO_LOG_STATS: + return self.do_log_stats(identity) + elif request == RPCUtilityRequest.IS_SERVER_READY: + return self.is_server_ready(identity) + elif request == RPCUtilityRequest.CHECK_HEALTH: + return self.check_health(identity) + elif request == RPCUtilityRequest.IS_TRACING_ENABLED: + return self.is_tracing_enabled(identity) + else: + raise ValueError(f"Unknown RPCUtilityRequest type: {request}") # noqa + + else: + raise ValueError(f"Unknown RPCRequest type: {request}") # noqa + + async def run_server_loop(self): + """Inner RPC Server Loop""" + + running_tasks = set() + while True: + # Wait for a request. + identity, message = await self.socket.recv_multipart() + + # Process the request async. + task = asyncio.create_task( + self._make_handler_coro(identity, message)) + + # We need to keep around a strong reference to the task, + # to avoid the task disappearing mid-execution as running tasks + # can be GC'ed. Below is a common "fire-and-forget" tasks + # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + running_tasks.add(task) + task.add_done_callback(running_tasks.discard) + + +async def run_server(server: AsyncEngineRPCServer): + # Put the server task into the asyncio loop. + loop = asyncio.get_running_loop() + server_task = loop.create_task(server.run_server_loop()) + + # Interruption handling. 
+ def signal_handler() -> None: + # Kill the server on interrupt / terminate + server_task.cancel() + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + try: + await server_task + except asyncio.CancelledError: + logger.info("vLLM ZMQ RPC Server was interrupted.") + finally: + # Clean up all resources. + server.cleanup() + + +def run_rpc_server(async_engine_args: AsyncEngineArgs, + usage_context: UsageContext, port: int, load_in_low_bit: str): + server = AsyncEngineRPCServer(async_engine_args, usage_context, port, load_in_low_bit) + asyncio.run(run_server(server)) diff --git a/python/llm/src/ipex_llm/vllm/xpu/ipex_llm_gpu_executor.py b/python/llm/src/ipex_llm/vllm/xpu/ipex_llm_gpu_executor.py deleted file mode 100644 index bf8c8abe..00000000 --- a/python/llm/src/ipex_llm/vllm/xpu/ipex_llm_gpu_executor.py +++ /dev/null @@ -1,466 +0,0 @@ -import asyncio -import copy -from collections import defaultdict -import os -import pickle -import importlib -from typing import TYPE_CHECKING, Any, Dict, List, Optional - -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig, LoRAConfig) -from vllm.engine.ray_utils import RayWorkerVllm, ray -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.executor.utils import check_block_size_valid -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.utils import (set_cuda_visible_devices, get_ip, get_open_port, - get_distributed_init_method, make_async) -import functools -from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert -from ipex_llm.utils.common import invalidInputError - -if ray is not None: - from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - -if TYPE_CHECKING: - from ray.util.placement_group import PlacementGroup - -logger = init_logger(__name__) - -# A map between the device type (in device config) to its worker module. -DEVICE_TO_WORKER_MODULE_MAP = { - "cuda": "vllm.worker.worker", - "xpu": "vllm.worker.worker", - "neuron": "vllm.worker.neuron_worker", -} - -# If the env var is set, it uses the Ray's compiled DAG API -# which optimizes the control plane overhead. -# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. -USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) - - -class IPEXLLMGPUExecutor(ExecutorBase): - - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - load_in_low_bit: str, - ) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.load_in_low_bit = load_in_low_bit - - invalidInputError(self.parallel_config.worker_use_ray, - "worker_use_ray is False, but use ray worker") - placement_group = self.parallel_config.placement_group - - # Disable Ray usage stats collection. - ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") - if ray_usage != "1": - os.environ["RAY_USAGE_STATS_ENABLED"] = "0" - - # Create the parallel GPU workers. - self._init_workers_ray(placement_group) - - # Profile the memory usage and initialize the cache. 
- self._init_cache() - - self.forward_dag = None - if USE_RAY_COMPILED_DAG: - self.forward_dag = self._compiled_ray_dag() - - def _dispatch_worker(self): - worker_module = DEVICE_TO_WORKER_MODULE_MAP[ - self.device_config.device_type] - imported_worker = importlib.import_module(worker_module) - Worker = imported_worker.Worker - return Worker - - def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): - if self.parallel_config.tensor_parallel_size == 1: - # For single GPU case, we use a ray worker with constrained memory. - num_gpus = self.cache_config.gpu_memory_utilization - else: - # Otherwise, the ray workers are allocated with a full GPU. - num_gpus = 1 - - # The driver dummy worker does not actually use any resources. - # It holds the resource for the driver worker. - self.driver_dummy_worker: RayWorkerVllm = None - # The remaining workers are the actual ray actors. - self.workers: List[RayWorkerVllm] = [] - - # Create the workers. - driver_ip = get_ip() - for bundle_id, bundle in enumerate(placement_group.bundle_specs): - if not bundle.get("GPU", 0): - continue - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=placement_group, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=bundle_id, - ) - worker = ray.remote( - num_cpus=0, - num_gpus=num_gpus, - scheduling_strategy=scheduling_strategy, - **ray_remote_kwargs, - )(RayWorkerVllm).remote(self.model_config.trust_remote_code) - - worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - else: - # Else, added to the list of workers. - self.workers.append(worker) - - if self.driver_dummy_worker is None: - invalidInputError(False, - "Ray does not allocate any GPUs on the driver node. Consider " - "adjusting the Ray placement group or running the driver on a " - "GPU node.") - # Get the set of GPU IDs used on each node. - driver_node_id, driver_gpu_ids = ray.get( - self.driver_dummy_worker.get_node_and_gpu_ids.remote()) - worker_node_and_gpu_ids = ray.get( - [worker.get_node_and_gpu_ids.remote() for worker in self.workers]) - - node_workers = defaultdict(list) - node_gpus = defaultdict(list) - - node_workers[driver_node_id].append(0) - node_gpus[driver_node_id].extend(driver_gpu_ids) - for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, - start=1): - node_workers[node_id].append(i) - node_gpus[node_id].extend(gpu_ids) - for node_id, gpu_ids in node_gpus.items(): - node_gpus[node_id] = sorted(gpu_ids) - - # Set CUDA_VISIBLE_DEVICES for the driver and workers. 
- set_cuda_visible_devices(node_gpus[driver_node_id]) - for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): - worker.set_cuda_visible_devices.remote(node_gpus[node_id]) - - distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) - - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - Worker = self._dispatch_worker() - - model_config = copy.deepcopy(self.model_config) - parallel_config = copy.deepcopy(self.parallel_config) - scheduler_config = copy.deepcopy(self.scheduler_config) - device_config = copy.deepcopy(self.device_config) - lora_config = copy.deepcopy(self.lora_config) - kv_cache_dtype = self.cache_config.cache_dtype - - # Initialize the actual workers with the Worker class. - for rank, (worker, (node_id, _)) in enumerate( - zip(self.workers, worker_node_and_gpu_ids), - start=1, - ): - local_rank = node_workers[node_id].index(rank) - - def create_worker_function(rank, local_rank, load_in_low_bit): - def worker_function(): - _ipex_llm_convert(load_in_low_bit) - return Worker( - model_config, - parallel_config, - scheduler_config, - device_config, - local_rank, - rank, - distributed_init_method, - lora_config=lora_config, - kv_cache_dtype=kv_cache_dtype, - ) - return worker_function - worker.init_worker.remote(create_worker_function(rank, - local_rank, - self.load_in_low_bit)) - - # Initialize the driver worker with the Worker class. - driver_rank = 0 - driver_local_rank = node_workers[driver_node_id].index(driver_rank) - self.driver_worker = Worker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, - driver_local_rank, - driver_rank, - distributed_init_method, - lora_config=self.lora_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=True, - ) - - # We want to apply patch here before we loading the model - # FIXME(woosuk): We are not properly initializing cupy NCCL when - # we have multiple nodes. - self._run_workers("init_model", - cupy_port=get_open_port() - if not model_config.enforce_eager else None) - self._run_workers( - "load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, - ) - - def _init_cache(self) -> None: - """Profiles the memory usage and initializes the KV cache. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - More details can be found in the - :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method - from class :class:`~vllm.worker.Worker`. - - Afterwards, as there may be multiple workers, - we take the minimum number of blocks across all workers - to ensure this can be applied to all of them. - - Finally, the engine will initialize the KV cache - with the calculated number of blocks. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. 
- num_blocks = self._run_workers( - "profile_num_available_blocks", - block_size=self.cache_config.block_size, - gpu_memory_utilization=self.cache_config.gpu_memory_utilization, - cpu_swap_space=self.cache_config.swap_space_bytes, - cache_dtype=self.cache_config.cache_dtype, - ) - - # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory - # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) - logger.info(f"# GPU blocks: {num_gpu_blocks}, " - f"# CPU blocks: {num_cpu_blocks}") - - check_block_size_valid(num_gpu_blocks, self.cache_config.block_size, - self.model_config.max_model_len) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - # Initialize the cache. - self._run_workers("init_cache_engine", cache_config=self.cache_config) - # Warm up the model. This includes capturing the model into CUDA graph - # if enforce_eager is False. - self._run_workers("warm_up_model") - - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - all_outputs = self._run_workers( - "execute_model", - driver_kwargs={ - "seq_group_metadata_list": seq_group_metadata_list, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - }, - use_ray_compiled_dag=USE_RAY_COMPILED_DAG) - - # Only the driver worker returns the sampling results. - output = all_outputs[0] - return output - - def add_lora(self, lora_request: LoRARequest) -> bool: - invalidInputError(lora_request.lora_int_id > 0, - "lora_id must be greater than 0.") - return self._run_workers( - "add_lora", - lora_request=lora_request, - ) - - def remove_lora(self, lora_id: int) -> bool: - invalidInputError(lora_id > 0, "lora_id must be greater than 0.") - return self._run_workers( - "remove_lora", - lora_id=lora_id, - ) - - def list_loras(self) -> List[int]: - return self._run_workers("list_loras") - - def _run_workers( - self, - method: str, - *args, - driver_args: Optional[List[Any]]=None, - driver_kwargs: Optional[Dict[str, Any]]=None, - max_concurrent_workers: Optional[int] = None, - use_ray_compiled_dag: bool = False, - **kwargs, - ) -> Any: - """Runs the given method on all workers.""" - - if max_concurrent_workers: - invalidInputError(False, - "max_concurrent_workers is not supported yet.") - - if use_ray_compiled_dag: - # Right now, compiled DAG can only accept a single - # input. TODO(sang): Fix it. - output_channels = self.forward_dag.execute(1) - else: - # Start the ray workers first. - ray_worker_outputs = [ - worker.execute_method.remote(method, *args, **kwargs) - for worker in self.workers - ] - - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs - - # Start the driver worker after all the ray workers. - driver_worker_output = getattr(self.driver_worker, - method)(*driver_args, **driver_kwargs) - - # Get the results of the ray workers. - if self.workers: - if use_ray_compiled_dag: - try: - ray_worker_outputs = [ - pickle.loads(chan.begin_read()) - for chan in output_channels - ] - finally: - # Has to call end_read in order to reuse the DAG. 
- for chan in output_channels: - chan.end_read() - else: - ray_worker_outputs = ray.get(ray_worker_outputs) - - return [driver_worker_output] + ray_worker_outputs - - def _compiled_ray_dag(self): - import pkg_resources - required_version = "2.9" - current_version = pkg_resources.get_distribution("ray").version - if current_version < required_version: - invalidInputError(False, - f"Ray version {required_version} or greater is " - f"required, but found {current_version}") - - from ray.dag import MultiOutputNode, InputNode - invalidInputError(self.parallel_config.worker_use_ray, - "Use ray worker, but worker_use_ray is False") - - # Right now, compiled DAG requires at least 1 arg. We send - # a dummy value for now. It will be fixed soon. - with InputNode() as input_data: - forward_dag = MultiOutputNode([ - worker.execute_model_compiled_dag_remote.bind(input_data) - for worker in self.workers - ]) - return forward_dag.experimental_compile() - - def check_health(self) -> None: - """Raises an error if engine is unhealthy.""" - self._check_if_any_actor_is_dead() - - def _check_if_any_actor_is_dead(self): - if not self.workers: - return - - dead_actors = [] - for actor in self.workers: - actor_state = ray.state.actors(actor._ray_actor_id.hex()) - if actor_state["State"] == "DEAD": - dead_actors.append(actor) - if dead_actors: - invalidInputError("At least one Worker is dead. " - f"Dead Workers: {dead_actors}. ") - - -class IPEXLLMGPUExecutorAsync(IPEXLLMGPUExecutor, ExecutorAsyncBase): - - async def _run_workers_async( - self, - method: str, - *args, - driver_args: Optional[List[Any]]=None, - driver_kwargs: Optional[Dict[str, Any]]=None, - **kwargs, - ) -> Any: - """Runs the given method on all workers.""" - coros = [] - - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs - - # Run the driver worker asynchronously. - driver_executor = make_async(getattr(self.driver_worker, method)) - coros.append(driver_executor(*driver_args, **driver_kwargs)) - - # Run the ray workers asynchronously. - for worker in self.workers: - coros.append(worker.execute_method.remote(method, *args, **kwargs)) - - all_outputs = await asyncio.gather(*coros) - return all_outputs - - async def execute_model_async( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: - all_outputs = await self._run_workers_async( - "execute_model", - driver_kwargs={ - "seq_group_metadata_list": seq_group_metadata_list, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - }) - - # Only the driver worker returns the sampling results. 
-        output = all_outputs[0]
-        return output
-
-    async def check_health_async(self) -> None:
-        """Raises an error if engine is unhealthy."""
-        self._check_if_any_actor_is_dead()
-
-
-def get_gpu_executor_class(load_in_low_bit):
-    return functools.partial(IPEXLLMGPUExecutor, load_in_low_bit=load_in_low_bit)
-
-
-def get_gpu_executor_class_async(load_in_low_bit):
-    return functools.partial(IPEXLLMGPUExecutorAsync, load_in_low_bit=load_in_low_bit)
diff --git a/python/llm/src/ipex_llm/vllm/xpu/ipex_llm_wrapper.py b/python/llm/src/ipex_llm/vllm/xpu/ipex_llm_wrapper.py
new file mode 100644
index 00000000..47b351d1
--- /dev/null
+++ b/python/llm/src/ipex_llm/vllm/xpu/ipex_llm_wrapper.py
@@ -0,0 +1,24 @@
+from vllm.logger import init_logger
+from vllm.executor.ray_utils import RayWorkerWrapper
+
+
+logger = init_logger(__name__)
+
+
+class IPEXLLMWrapper(RayWorkerWrapper):
+    def __init__(self, load_in_low_bit="sym_int4", *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
+        _ipex_llm_convert(load_in_low_bit=load_in_low_bit)
+        self.compiled_dag_cuda_device_set = False
+
+
+def get_ipex_llm_wrapper(load_in_low_bit):
+    # We do not use functools.partial here because Ray does not seem to
+    # work well with it.
+    class WrapperWithLoadBit(IPEXLLMWrapper):
+        def __init__(self, *args, **kwargs) -> None:
+            super().__init__(load_in_low_bit=load_in_low_bit, *args, **kwargs)
+
+    # a = functools.partial(IPEXLLMWrapper, load_in_low_bit=load_in_low_bit)
+    return WrapperWithLoadBit
diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
index 065652e7..23e6fd97 100644
--- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
+++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
@@ -14,6 +14,8 @@
 # limitations under the License.
# import torch +from typing import Optional, Union +from vllm.distributed import tensor_model_parallel_gather, tensor_model_parallel_all_gather from vllm.logger import init_logger from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention, LlamaForCausalLM from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention, Qwen2ForCausalLM @@ -22,236 +24,78 @@ from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention, ChatGLMForCausalLM from vllm.model_executor.model_loader import get_model -from vllm.model_executor.layers.sampler import Sampler - -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.attention import AttentionMetadata from vllm.config import DeviceConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather - -from typing import Tuple, Optional, Union -from ipex_llm.utils.common import invalidInputError -from vllm.sequence import SamplerOutput +from typing import Tuple +from ipex_llm.transformers.low_bit_linear import LowBitLinear -def _Llama_sample( +def _sample_get_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head, hidden_states, - sampling_metadata) - return next_tokens - - -def _Qwen2_sample( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> Optional[SamplerOutput]: - if self.config.tie_word_embeddings: - # Embedding layer is not optimized to LowBitLinear - lm_head_weight = self.model.embed_tokens.weight + lm_head: Union[VocabParallelEmbedding, LowBitLinear], + embedding_bias: Optional[torch.Tensor], +) -> torch.Tensor: + # HINT: we do not support other types of quantization for now + # TODO: we may encounter tie-word-embedding problems + if isinstance(lm_head, VocabParallelEmbedding): + logits = lm_head.linear_method.apply(lm_head, + hidden_states, + bias=embedding_bias) else: - # This layer is optimized to LowBitLinear - lm_head_weight = self.lm_head - next_tokens = self.sampler(lm_head_weight, hidden_states, - sampling_metadata) - return next_tokens - - -def _Chatglm_sample( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.transformer.output_layer, hidden_states, - sampling_metadata) - - return next_tokens - - -def _sample_get_logits(self, hidden_states: torch.Tensor, - embedding: Union[torch.nn.Module, torch.Tensor], - embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: - # For tie_word_embedding models, the embedding is not optimized as - # the low_bit_linear layer... - if isinstance(embedding, torch.Tensor): - logits = torch.matmul(hidden_states, embedding.t()) + logits = lm_head(hidden_states) + if embedding_bias is not None: + logits += embedding_bias + if self.use_gather: + logits = tensor_model_parallel_gather(logits) else: - logits = embedding(hidden_states) - if embedding_bias is not None: - logits += embedding_bias - logits = tensor_model_parallel_gather(logits) - # Remove paddings in vocab (if any). 
+ logits = tensor_model_parallel_all_gather(logits) if logits is not None: - logits = logits[:, :self.org_vocab_size] + logits = logits[:, : self.org_vocab_size] return logits -def _MLP_forward(self, x): - gate_up = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x = self.down_proj(x) - return x - - -def _Attention_forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - input_metadata: InputMetadata, -) -> torch.Tensor: - qkv = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output = self.o_proj(attn_output) - return output - - -def _QWen_Attention_forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor], - input_metadata: InputMetadata, -) -> torch.Tensor: - qkv = self.c_attn(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output = self.c_proj(attn_output) - return output - - -def _QWen_MLP_forward(self, x): - gate_up = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x = self.c_proj(x) - return x - - -def _ChatGLM_MLP_forward(self, hidden_states): - # [s, b, 4hp] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - intermediate_parallel = self.activation_func(intermediate_parallel) - # [s, b, h] - output = self.dense_4h_to_h(intermediate_parallel) - return output - - -def _Baichuan_Attention_forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor], - input_metadata: InputMetadata, -) -> torch.Tensor: - qkv = self.W_pack(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) - if self.postion_embedding != "ALIBI": - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output = self.o_proj(attn_output) - return output - - -def _ChatGLM_Attention_forward( - self, - hidden_states: torch.Tensor, - position_ids: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor], - input_metadata: InputMetadata, -) -> torch.Tensor: - qkv = self.query_key_value(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(position_ids, q, k) - key_cache, value_cache = kv_cache - context_layer = self.attn( - q, - k, - v, - key_cache, - value_cache, - input_metadata, - ) - attn_output = self.dense(context_layer) - return attn_output - -_REPLACED_MLP_LAYERS = { - LlamaMLP: _MLP_forward, - Qwen2MLP: _MLP_forward, - BaiChuanMLP: _MLP_forward, - QWenMLP: _QWen_MLP_forward, - GLMMLP: _ChatGLM_MLP_forward -} - -_REPLACED_ATTENTION_LAYERS = { - LlamaAttention: _Attention_forward, - Qwen2Attention: _Attention_forward, - QWenAttention: _QWen_Attention_forward, - BaiChuanAttention: _Baichuan_Attention_forward, - GLMAttention: _ChatGLM_Attention_forward -} - -_REPLACED_SAMPLER_LAYERS = { - LlamaForCausalLM: _Llama_sample, - QWenLMHeadModel: _Llama_sample, - ChatGLMForCausalLM: _Chatglm_sample, - Qwen2ForCausalLM: _Qwen2_sample, - BaiChuanBaseForCausalLM: _Llama_sample, -} - - -def _model_mlp_convert(): - for module, replaced_func in _REPLACED_MLP_LAYERS.items(): - setattr(module, "forward", replaced_func) - - def _model_sample_convert(): - 
setattr(Sampler, "_get_logits", _sample_get_logits) - for module, replaced_func in _REPLACED_SAMPLER_LAYERS.items(): - setattr(module, "sample", replaced_func) - - -def _model_attention_convert(): - for module, replaced_func in _REPLACED_ATTENTION_LAYERS.items(): - setattr(module, "forward", replaced_func) + from vllm.model_executor.layers.logits_processor import LogitsProcessor + setattr(LogitsProcessor, "_get_logits", _sample_get_logits) def _ipex_llm_convert(load_in_low_bit): - from vllm.worker.model_runner import ModelRunner - import vllm.model_executor.model_loader as model_loader - setattr(ModelRunner, "load_model", get_load_function(load_in_low_bit)) + from vllm.worker.xpu_model_runner import XPUModelRunner + from ipex_llm.vllm.xpu.ipex_llm_wrapper import get_ipex_llm_wrapper + import vllm.executor.ray_utils as ray_utils + setattr(XPUModelRunner, "load_model", get_load_function(load_in_low_bit)) + setattr(ray_utils, "RayWorkerWrapper", get_ipex_llm_wrapper(load_in_low_bit)) def get_load_function(low_bit): def _ipex_llm_load_model(self) -> None: - # _model_mlp_convert() - # _model_attention_convert() _model_sample_convert() - from vllm.utils import measure_device_memory - with measure_device_memory() as m: - # only support xpu for now - # We have to create a new DeviceConfig. - # Otherwise, will get the wrong xpu memory usage - self.model = get_model(self.model_config, - DeviceConfig("cpu"), - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + # from vllm.utils import measure_device_memory + from vllm.utils import CudaMemoryProfiler + with CudaMemoryProfiler() as m: + self.model = get_model( + model_config=self.model_config, + device_config=DeviceConfig("cpu"), + load_config=self.load_config, + lora_config=self.lora_config, + multimodal_config=self.multimodal_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config, + ) + if "qwen" in self.model_config.model.lower() and \ + self.model.model.layers[0].mlp.down_proj.input_size_per_partition % 256 != 0: + self.model.apply(padding_mlp) from ipex_llm import optimize_model import os not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None) is_glm4_model = "glm-4" in self.model_config.model.lower() - if not_convert_last_mlp is not None or is_glm4_model: + is_codegeex4_model = "codegeex4-all" in self.model_config.model.lower() + if not_convert_last_mlp is not None or is_glm4_model or is_codegeex4_model: # only use to avoid nan value in last mlp forward running glm4-9b-chat modules = ["35.mlp", "36.mlp", "37.mlp", "38.mlp", "39.mlp"] else: @@ -263,22 +107,34 @@ def get_load_function(low_bit): self.model_memory_usage = m.consumed_memory logger = init_logger(__name__) - logger.info(f"Loading model weights took " - f"{self.model_memory_usage / float(2**30):.4f} GB") + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) - if self.lora_config: - invalidInputError(hasattr(self.model, "supported_lora_modules") - and self.model.supported_lora_modules, - "Model does not support LoRA") - invalidInputError(hasattr(self.model, "embedding_modules"), - "Model does not have embedding_modules") - invalidInputError(hasattr(self.model, "embedding_padding_modules"), - "Model does not have embedding_padding_modules") - self.lora_manager = LRUCacheWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens + - 
self.scheduler_config.max_paddings, self.vocab_size,
-                self.lora_config, self.device, self.model.embedding_modules,
-                self.model.embedding_padding_modules)
-            self.model = self.lora_manager.create_lora_manager(self.model)
     return _ipex_llm_load_model
+
+
+def padding_mlp(module: torch.nn.Module):
+    if isinstance(module, Qwen2MLP):
+        hidden_size = module.down_proj.output_size
+        # intermediate size per tensor-parallel rank
+        intermediate_size = module.down_proj.input_size_per_partition
+        padding_size = 256
+        padding_intermediate_size = \
+            (intermediate_size + padding_size - 1) // padding_size * padding_size
+        if intermediate_size % padding_size == 0:
+            return
+        gate_up_weight = module.gate_up_proj.weight.data
+        new_gate_up_weight = torch.zeros([padding_intermediate_size * 2, hidden_size],
+                                         dtype=gate_up_weight.dtype, device=gate_up_weight.device)
+        # copy the gate and up projections into the zero-padded buffer
+        new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :]
+        new_gate_up_weight[padding_intermediate_size:padding_intermediate_size+intermediate_size, :] = gate_up_weight[intermediate_size:, :]  # noqa
+        module.gate_up_proj.output_size_per_partition = padding_intermediate_size * 2
+        module.gate_up_proj.weight = torch.nn.Parameter(new_gate_up_weight, requires_grad=False)
+
+        down_weight = module.down_proj.weight.data
+        new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
+                                      dtype=down_weight.dtype, device=down_weight.device)
+        new_down_weight[:, :intermediate_size] = down_weight
+        module.down_proj.input_size_per_partition = padding_intermediate_size
+        module.down_proj.weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
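
For reference, below is a minimal, self-contained sketch of the padding arithmetic that padding_mlp applies to Qwen2 MLP weights. The helper name pad_to_multiple and the example shapes (a Qwen2-7B-style MLP sharded across 4 tensor-parallel ranks) are illustrative assumptions, and plain tensors stand in for vLLM's parallel linear layers; it is not part of the patch itself.

import torch

def pad_to_multiple(size: int, multiple: int = 256) -> int:
    # Same round-up arithmetic as padding_mlp above.
    return (size + multiple - 1) // multiple * multiple

# Assumed example shapes: roughly a Qwen2-7B MLP split over 4 ranks.
hidden_size = 3584
intermediate_size = 18944 // 4                 # 4736 per rank, not a multiple of 256
padded = pad_to_multiple(intermediate_size)    # rounded up to 4864

# gate_up_proj stacks the gate and up projections along dim 0.
gate_up_weight = torch.randn(intermediate_size * 2, hidden_size)
new_gate_up = torch.zeros(padded * 2, hidden_size, dtype=gate_up_weight.dtype)
new_gate_up[:intermediate_size] = gate_up_weight[:intermediate_size]                 # gate half
new_gate_up[padded:padded + intermediate_size] = gate_up_weight[intermediate_size:]  # up half

# down_proj consumes the (now padded) intermediate activation.
down_weight = torch.randn(hidden_size, intermediate_size)
new_down = torch.zeros(hidden_size, padded, dtype=down_weight.dtype)
new_down[:, :intermediate_size] = down_weight

print(new_gate_up.shape, new_down.shape)  # torch.Size([9728, 3584]) torch.Size([3584, 4864])

The zero columns contribute nothing to the MLP output, so the padding only changes the per-rank weight shapes, which is what allows the batched forward path to be used when the original intermediate size is not a multiple of 256.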