initial implementation for low_bit_loader vLLM (#12838)

* initial

* add logic for handling tensor parallel models

* fix

* Add some comments

* add doc

* fix done
Guancheng Fu 2025-02-19 19:45:34 +08:00 committed by GitHub
parent c81b7fc003
commit 4eed0c7d99
4 changed files with 88 additions and 37 deletions


@@ -1,8 +1,8 @@
# vLLM continuous batching on Intel GPUs (experimental support)
# vLLM continuous batching on Intel GPUs
This example demonstrates how to serve a LLaMA2-7B model using vLLM continuous batching on Intel GPUs (with IPEX-LLM low-bit optimizations).
The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.6.2).
The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.6.6).
Currently, we support the following models for vLLM engine:
@@ -10,6 +10,8 @@ Currently, we support the following models for vLLM engine:
- Llama series models
- ChatGLM series models
- Baichuan series models
- Deepseek series models
- Multimodal models
## Example: Serving LLaMA2-7B using Intel GPU
@@ -17,7 +19,9 @@ In this example, we will run Llama2-7b model using Arc A770 and provide `OpenAI-
### 0. Environment
To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.1. Please check the requirements at [here](https://www.intel.com/content/www/us/en/docs/oneapi/installation-guide-linux/2024-1/overview.html).
To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2025.0.1. Please check the requirements [here](https://www.intel.com/content/www/us/en/docs/oneapi/installation-guide-linux/2025-0/overview.html).
You may also want to install the latest compute runtime from [here](https://github.com/intel/compute-runtime/releases).
After installing the toolkit, run the following commands in your environment before starting vLLM on the GPU:
```bash
@@ -26,10 +30,9 @@ source /opt/intel/oneapi/setvars.sh
sycl-ls
# Example output with one Arc A770:
[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device 1.2 [2023.16.7.0.21_160000]
[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i9-13900K 3.0 [2023.16.7.0.21_160000]
[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics 3.0 [23.17.26241.33]
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26241]
[level_zero:gpu][level_zero:0] Intel(R) oneAPI Unified Runtime over Level-Zero, Intel(R) Arc(TM) A770 Graphics 12.55.8 [1.6.32224.500000]
[opencl:cpu][opencl:0] Intel(R) OpenCL, Intel(R) Xeon(R) w5-3435X OpenCL 3.0 (Build 0) [2024.18.12.0.05_160000]
[opencl:gpu][opencl:1] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [24.52.32224.5]
```
### 1. Install
@@ -43,15 +46,16 @@ source /opt/intel/oneapi/setvars.sh
conda create -n ipex-vllm python=3.11
conda activate ipex-vllm
# Install dependencies
pip install --pre --upgrade "ipex-llm[xpu]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install --pre --upgrade "ipex-llm[xpu_2.6]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install setuptools-scm
pip install --upgrade cmake
# cd to your workdir
git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git
git clone -b 0.6.6 https://github.com/analytics-zoo/vllm.git
cd vllm
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v .
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm
# For Qwen model support
pip install transformers_stream_generator einops tiktoken
pip install ray
```
### 2. Configure recommended environment variables
@@ -205,3 +209,35 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--distributed-executor-backend ray \
--disable-async-output-proc
```
### 4. Load low-bit models with vLLM
To load a low-bit model directly with vLLM, use the `--low-bit-model-path` option when starting the service, or the `low_bit_model_path` argument when using `vllm_offline_inference.py`.
The low-bit model must first be saved to disk using the `--low-bit-save-path` or `low_bit_save_path` option.
For instance, to save an FP8 low-bit version of `DeepSeek-R1-Distill-Qwen-7B` to disk, we can execute the following Python script.
```python
from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
# Create an LLM.
llm = LLM(model="DeepSeek-R1-Distill-Qwen-7B",  # Unquantized model path on disk
          device="xpu",
          dtype="float16",
          enforce_eager=True,
          load_in_low_bit="fp8",  # The low-bit format to quantize to (FP8 in this example)
          tensor_parallel_size=1,  # The same tp size must be used later when loading the low-bit model
          disable_async_output_proc=True,
          distributed_executor_backend="ray",
          max_model_len=500,
          trust_remote_code=True,
          block_size=8,
          max_num_batched_tokens=500,
          low_bit_save_path="/llm/fp8-model-path")  # Path where the low-bit model will be saved
```
When the script finishes, the low-bit model has been saved to `/llm/fp8-model-path`.
Later, we can pass `--low-bit-model-path /llm/fp8-model-path` (or `low_bit_model_path` for offline inference) to serve directly from the saved low-bit model, as sketched below.
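For reference, here is a minimal sketch of the loading side for offline inference. It assumes that `IPEXLLMClass` accepts a `low_bit_model_path` keyword mirroring the `low_bit_save_path` option above (check `vllm_offline_inference.py` for the exact parameter name), and that settings such as `tensor_parallel_size` match those used when the model was saved.
```python
from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

# Sketch: load the previously saved low-bit model instead of quantizing again.
# `low_bit_model_path` is assumed to mirror `low_bit_save_path` and must point
# at the directory written by the save step above.
llm = LLM(model="DeepSeek-R1-Distill-Qwen-7B",  # original model path (config/tokenizer)
          device="xpu",
          dtype="float16",
          enforce_eager=True,
          load_in_low_bit="fp8",
          tensor_parallel_size=1,  # must match the value used when saving
          disable_async_output_proc=True,
          distributed_executor_backend="ray",
          max_model_len=500,
          trust_remote_code=True,
          block_size=8,
          max_num_batched_tokens=500,
          low_bit_model_path="/llm/fp8-model-path")  # directory produced by the save step

# Quick generation to verify that the low-bit weights load correctly.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)
outputs = llm.generate(["What is AI?"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```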


@@ -170,9 +170,11 @@ def load_low_bit(model, model_path):
invalidInputError(isinstance(model, torch.nn.Module),
"model should be an instance of `torch.nn.Module`, "
f"but got {type(model)} at last.")
invalidInputError(model.device.type in ('cpu', 'meta'),
"Expect model on device `cpu` or `meta`, "
f"but got device type {model.device.type}")
if hasattr(model, "device"):
# vLLM models do not have a device attribute
invalidInputError(model.device.type in ('cpu', 'meta'),
"Expect model on device `cpu` or `meta`, "
f"but got device type {model.device.type}")
qtype = ggml_tensor_qtype[low_bit]
model = ggml_convert_low_bit(model, qtype=qtype, convert_shape_only=True)


@@ -117,7 +117,6 @@ class IPEXLLMClass(LLM):
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False.
'''
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True


@@ -82,35 +82,49 @@ def get_load_function(low_bit):
# from vllm.utils import measure_device_memory
from vllm.utils import DeviceMemoryProfiler
with DeviceMemoryProfiler() as m:
import os
from dataclasses import replace
new_device_config = DeviceConfig("cpu")
new_vllm_config = replace(self.vllm_config, device_config=new_device_config)
# We are loading a low-bit model where all optimizations have already been
# applied, so we can skip the optimization steps below.
self.model = get_model(
vllm_config=new_vllm_config
)
if "qwen" in self.vllm_config.model_config.model.lower() or \
"baichuan" in self.vllm_config.model_config.model.lower() or \
"codegeex4-all" in self.vllm_config.model_config.model.lower() or \
"chatglm" in self.vllm_config.model_config.model.lower():
self.model.apply(padding_mlp)
from ipex_llm import optimize_model
import os
not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)
if not_convert_last_mlp is not None:
# only use to avoid nan value in last mlp forward running glm4-9b-chat
modules = ["35.mlp", "36.mlp", "37.mlp", "38.mlp", "39.mlp"]
else:
modules = None
if "minicpm" in self.vllm_config.model_config.model.lower():
modules = ["vpm", "resampler"]
if "internvl2" in self.vllm_config.model_config.model.lower():
modules = ["vision_model", "mlp1"]
if "deepseek-v2" in self.vllm_config.model_config.model.lower():
modules = ["down_proj"]
optimize_model(self.model,
low_bit=low_bit,
torch_dtype=self.vllm_config.model_config.dtype,
modules_to_not_convert=modules)
if self.vllm_config.model_config.low_bit_model_path is None:
if "qwen" in self.vllm_config.model_config.model.lower() or \
"baichuan" in self.vllm_config.model_config.model.lower() or \
"codegeex4-all" in self.vllm_config.model_config.model.lower() or \
"chatglm" in self.vllm_config.model_config.model.lower():
self.model.apply(padding_mlp)
from ipex_llm import optimize_model
not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)
if not_convert_last_mlp is not None:
# only used to avoid NaN values in the last MLP forward when running glm4-9b-chat
modules = ["35.mlp", "36.mlp", "37.mlp", "38.mlp", "39.mlp"]
else:
modules = None
if "minicpm" in self.vllm_config.model_config.model.lower():
modules = ["vpm", "resampler"]
if "internvl2" in self.vllm_config.model_config.model.lower():
modules = ["vision_model", "mlp1"]
if "deepseek-v2" in self.vllm_config.model_config.model.lower():
modules = ["down_proj"]
optimize_model(self.model,
low_bit=low_bit,
torch_dtype=self.vllm_config.model_config.dtype,
modules_to_not_convert=modules)
# Guancheng: We have to save the model before moving it to the XPU device.
# The `to` method converts the underlying data,
# so saving first avoids doing the conversion twice.
if self.vllm_config.model_config.low_bit_save_path is not None:
# The local_rank is used for loading models with tensor parallel settings.
local_rank = os.environ["LOCAL_RANK"]
saved_path = os.path.join(self.vllm_config.model_config.low_bit_save_path,
str(local_rank))
self.model.save_low_bit(saved_path)
self.model = self.model.to(device=self.vllm_config.device_config.device,
dtype=self.vllm_config.model_config.dtype)
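As a rough illustration of the directory layout the save logic above produces (a sketch, assuming `low_bit_save_path="/llm/fp8-model-path"` and a tensor parallel size of 2), each worker writes its shard under a sub-directory named after its `LOCAL_RANK`:
```python
import os

# Sketch of the per-rank save layout, assuming low_bit_save_path="/llm/fp8-model-path"
# and tensor_parallel_size=2; each rank saves into its own sub-directory.
low_bit_save_path = "/llm/fp8-model-path"
for local_rank in range(2):
    print(os.path.join(low_bit_save_path, str(local_rank)))
# Prints:
#   /llm/fp8-model-path/0
#   /llm/fp8-model-path/1
```
When reloading, `--low-bit-model-path` points at the parent directory (`/llm/fp8-model-path`); presumably the per-rank sub-directory is resolved at load time from `LOCAL_RANK`, mirroring the save path logic above.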