diff --git a/python/llm/example/GPU/vLLM-Serving/README.md b/python/llm/example/GPU/vLLM-Serving/README.md
index 0e644d8b..1ef70a39 100644
--- a/python/llm/example/GPU/vLLM-Serving/README.md
+++ b/python/llm/example/GPU/vLLM-Serving/README.md
@@ -1,8 +1,8 @@
-# vLLM continuous batching on Intel GPUs (experimental support)
+# vLLM continuous batching on Intel GPUs
 
 This example demonstrates how to serve a LLaMA2-7B model using vLLM continuous batching on Intel GPU (with IPEX-LLM low-bits optimizations).
 
-The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.6.2).
+The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.6.6).
 
 Currently, we support the following models for vLLM engine:
 
@@ -10,6 +10,8 @@ Currently, we support the following models for vLLM engine:
 - Llama series models
 - ChatGLM series models
 - Baichuan series models
+- Deepseek series models
+- Multimodal models
 
 ## Example: Serving LLaMA2-7B using Intel GPU
 
@@ -17,7 +19,9 @@ In this example, we will run Llama2-7b model using Arc A770 and provide `OpenAI-
 
 ### 0. Environment
 
-To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.1. Please check the requirements at [here](https://www.intel.com/content/www/us/en/docs/oneapi/installation-guide-linux/2024-1/overview.html).
+To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2025.0.1. Please check the requirements [here](https://www.intel.com/content/www/us/en/docs/oneapi/installation-guide-linux/2025-0/overview.html).
+
+In addition, you may want to install the latest compute runtime, available [here](https://github.com/intel/compute-runtime/releases).
 
 After install the toolkit, run the following commands in your environment before starting vLLM GPU:
 ```bash
@@ -26,10 +30,9 @@ source /opt/intel/oneapi/setvars.sh
 sycl-ls
 
 # Example output with one Arc A770:
-[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device 1.2 [2023.16.7.0.21_160000]
-[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i9-13900K 3.0 [2023.16.7.0.21_160000]
-[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics 3.0 [23.17.26241.33]
-[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26241]
+[level_zero:gpu][level_zero:0] Intel(R) oneAPI Unified Runtime over Level-Zero, Intel(R) Arc(TM) A770 Graphics 12.55.8 [1.6.32224.500000]
+[opencl:cpu][opencl:0] Intel(R) OpenCL, Intel(R) Xeon(R) w5-3435X OpenCL 3.0 (Build 0) [2024.18.12.0.05_160000]
+[opencl:gpu][opencl:1] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [24.52.32224.5]
 ```
 
 ### 1. Install
@@ -43,15 +46,16 @@ source /opt/intel/oneapi/setvars.sh
 conda create -n ipex-vllm python=3.11
 conda activate ipex-vllm
 # Install dependencies
-pip install --pre --upgrade "ipex-llm[xpu]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install --pre --upgrade "ipex-llm[xpu_2.6]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 pip install setuptools-scm
 pip install --upgrade cmake
 # cd to your workdir
-git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git
+git clone -b 0.6.6 https://github.com/analytics-zoo/vllm.git
 cd vllm
 VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v .
 # For Qwen model support
 pip install transformers_stream_generator einops tiktoken
+pip install ray
 ```
 
 ### 2. Configure recommended environment variables
@@ -205,3 +209,35 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
   --distributed-executor-backend ray \
   --disable-async-output-proc
 ```
+
+### 4. Load low-bit models with vLLM
+
+To load a low-bit model directly with vLLM, use the `--low-bit-model-path` option when starting the service, or the `low_bit_model_path` option when using `vllm_offline_inference.py`.
+
+The low-bit model first needs to be saved using the `--low-bit-save-path` or `low_bit_save_path` option.
+
+For instance, to save an FP8 low-bit `DeepSeek-R1-Distill-Qwen-7B` model to disk, we can execute the following Python script.
+
+```python
+from vllm import SamplingParams
+from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
+
+# Create an LLM.
+llm = LLM(model="DeepSeek-R1-Distill-Qwen-7B",  # Unquantized model path on disk
+          device="xpu",
+          dtype="float16",
+          enforce_eager=True,
+          load_in_low_bit="fp8",  # The low-bit format to quantize to
+          tensor_parallel_size=1,  # Must match the tensor parallel size used later when loading the low-bit model
+          disable_async_output_proc=True,
+          distributed_executor_backend="ray",
+          max_model_len=500,
+          trust_remote_code=True,
+          block_size=8,
+          max_num_batched_tokens=500,
+          low_bit_save_path="/llm/fp8-model-path")  # Path where the low-bit model will be saved
+```
+
+When the script finishes, the low-bit model has been saved to `/llm/fp8-model-path`.
+
+Later, we can pass `--low-bit-model-path /llm/fp8-model-path` to serve or load the low-bit model.
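+
+Loading the saved low-bit model for offline inference might then look like the following sketch. It mirrors the saving example above and assumes that `low_bit_model_path` is passed to `IPEXLLMClass` in the same way as `low_bit_save_path`; the other parameter values are illustrative, but `tensor_parallel_size` and the low-bit format should stay consistent with the values used when saving.
+
+```python
+from vllm import SamplingParams
+from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
+
+# Create an LLM that loads the previously saved low-bit weights.
+llm = LLM(model="DeepSeek-R1-Distill-Qwen-7B",  # Original (unquantized) model path, as in the saving example
+          device="xpu",
+          dtype="float16",
+          enforce_eager=True,
+          load_in_low_bit="fp8",  # Same low-bit format as used when saving
+          tensor_parallel_size=1,  # Must match the value used when saving
+          disable_async_output_proc=True,
+          distributed_executor_backend="ray",
+          max_model_len=500,
+          trust_remote_code=True,
+          block_size=8,
+          max_num_batched_tokens=500,
+          low_bit_model_path="/llm/fp8-model-path")  # Previously saved low-bit model
+
+# Standard vLLM generation API.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+outputs = llm.generate(["What is AI?"], sampling_params)
+for output in outputs:
+    print(output.outputs[0].text)
+```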
diff --git a/python/llm/src/ipex_llm/optimize.py b/python/llm/src/ipex_llm/optimize.py
index 0bf1c410..ca910c64 100644
--- a/python/llm/src/ipex_llm/optimize.py
+++ b/python/llm/src/ipex_llm/optimize.py
@@ -170,9 +170,11 @@ def load_low_bit(model, model_path):
     invalidInputError(isinstance(model, torch.nn.Module),
                       "model should be an instance of `torch.nn.Module`, "
                       f"but got {type(model)} at last.")
-    invalidInputError(model.device.type in ('cpu', 'meta'),
-                      "Expect model on device `cpu` or `meta`, "
-                      f"but got device type {model.device.type}")
+    if hasattr(model, "device"):
+        # vLLM models do not have a device attribute
+        invalidInputError(model.device.type in ('cpu', 'meta'),
+                          "Expect model on device `cpu` or `meta`, "
+                          f"but got device type {model.device.type}")
     qtype = ggml_tensor_qtype[low_bit]
     model = ggml_convert_low_bit(model, qtype=qtype, convert_shape_only=True)
diff --git a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py
index e6424493..178da383 100644
--- a/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py
+++ b/python/llm/src/ipex_llm/vllm/xpu/engine/engine.py
@@ -117,7 +117,6 @@ class IPEXLLMClass(LLM):
             Note: if enforce_eager is unset (enforce_eager is None)
             it defaults to False.
         '''
-
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
index 12963b72..15a9818a 100644
--- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
+++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
@@ -82,35 +82,49 @@ def get_load_function(low_bit):
         # from vllm.utils import measure_device_memory
         from vllm.utils import DeviceMemoryProfiler
         with DeviceMemoryProfiler() as m:
+            import os
             from dataclasses import replace
             new_device_config = DeviceConfig("cpu")
             new_vllm_config = replace(self.vllm_config, device_config=new_device_config)
+            # If we are loading a saved low-bit model, all of the optimizations have
+            # already been applied when it was saved, so we can skip the
+            # conversion steps below.
             self.model = get_model(
                 vllm_config=new_vllm_config
             )
-            if "qwen" in self.vllm_config.model_config.model.lower() or \
-                    "baichuan" in self.vllm_config.model_config.model.lower() or \
-                    "codegeex4-all" in self.vllm_config.model_config.model.lower() or \
-                    "chatglm" in self.vllm_config.model_config.model.lower():
-                self.model.apply(padding_mlp)
-            from ipex_llm import optimize_model
-            import os
-            not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)
-            if not_convert_last_mlp is not None:
-                # only use to avoid nan value in last mlp forward running glm4-9b-chat
-                modules = ["35.mlp", "36.mlp", "37.mlp", "38.mlp", "39.mlp"]
-            else:
-                modules = None
-            if "minicpm" in self.vllm_config.model_config.model.lower():
-                modules = ["vpm", "resampler"]
-            if "internvl2" in self.vllm_config.model_config.model.lower():
-                modules = ["vision_model", "mlp1"]
-            if "deepseek-v2" in self.vllm_config.model_config.model.lower():
-                modules = ["down_proj"]
-            optimize_model(self.model,
-                           low_bit=low_bit,
-                           torch_dtype=self.vllm_config.model_config.dtype,
-                           modules_to_not_convert=modules)
+            if self.vllm_config.model_config.low_bit_model_path is None:
+                if "qwen" in self.vllm_config.model_config.model.lower() or \
+                        "baichuan" in self.vllm_config.model_config.model.lower() or \
+                        "codegeex4-all" in self.vllm_config.model_config.model.lower() or \
+                        "chatglm" in self.vllm_config.model_config.model.lower():
+                    self.model.apply(padding_mlp)
+                from ipex_llm import optimize_model
+                not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)
+                if not_convert_last_mlp is not None:
+                    # only used to avoid nan values in the last mlp forward when running glm4-9b-chat
+                    modules = ["35.mlp", "36.mlp", "37.mlp", "38.mlp", "39.mlp"]
+                else:
+                    modules = None
+                if "minicpm" in self.vllm_config.model_config.model.lower():
+                    modules = ["vpm", "resampler"]
+                if "internvl2" in self.vllm_config.model_config.model.lower():
+                    modules = ["vision_model", "mlp1"]
+                if "deepseek-v2" in self.vllm_config.model_config.model.lower():
+                    modules = ["down_proj"]
+                optimize_model(self.model,
+                               low_bit=low_bit,
+                               torch_dtype=self.vllm_config.model_config.dtype,
+                               modules_to_not_convert=modules)
+            # We have to save the model before moving it to the XPU device:
+            # the `to` method converts the underlying data, so saving the model
+            # first avoids converting it twice.
+            if self.vllm_config.model_config.low_bit_save_path is not None:
+                # The local_rank is used for loading models with tensor parallel settings.
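+                # Note: each tensor-parallel rank saves its own shard under
+                # <low_bit_save_path>/<local_rank>, which is why the tensor parallel
+                # size must stay the same when the low-bit model is loaded back.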
+                local_rank = os.environ["LOCAL_RANK"]
+                saved_path = os.path.join(self.vllm_config.model_config.low_bit_save_path,
+                                          str(local_rank))
+                self.model.save_low_bit(saved_path)
+
             self.model = self.model.to(device=self.vllm_config.device_config.device,
                                        dtype=self.vllm_config.model_config.dtype)