initial implementation for low_bit_loader vLLM (#12838)

* initial
* add logic for handling tensor parallel models
* fix
* Add some comments
* add doc
* fix done

This commit is contained in:
parent c81b7fc003
commit 4eed0c7d99

4 changed files with 88 additions and 37 deletions
@@ -1,8 +1,8 @@
-# vLLM continuous batching on Intel GPUs (experimental support)
+# vLLM continuous batching on Intel GPUs

 This example demonstrates how to serve a LLaMA2-7B model using vLLM continuous batching on Intel GPU (with IPEX-LLM low-bit optimizations).

-The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.6.2).
+The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.6.6).

 Currently, we support the following models for vLLM engine:
@@ -10,6 +10,8 @@ Currently, we support the following models for vLLM engine:
 - Llama series models
 - ChatGLM series models
 - Baichuan series models
+- Deepseek series models
+- Multimodal models

 ## Example: Serving LLaMA2-7B using Intel GPU
@@ -17,7 +19,9 @@ In this example, we will run Llama2-7b model using Arc A770 and provide `OpenAI-
 ### 0. Environment

-To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.1. Please check the requirements [here](https://www.intel.com/content/www/us/en/docs/oneapi/installation-guide-linux/2024-1/overview.html).
+To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2025.0.1. Please check the requirements [here](https://www.intel.com/content/www/us/en/docs/oneapi/installation-guide-linux/2025-0/overview.html).
+
+You may also want to install the latest compute runtime from [here](https://github.com/intel/compute-runtime/releases).

 After installing the toolkit, run the following commands in your environment before starting vLLM on the GPU:
 ```bash
@@ -26,10 +30,9 @@ source /opt/intel/oneapi/setvars.sh
 sycl-ls

 # Example output with one Arc A770:
-[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device 1.2 [2023.16.7.0.21_160000]
-[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i9-13900K 3.0 [2023.16.7.0.21_160000]
-[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics 3.0 [23.17.26241.33]
-[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26241]
+[level_zero:gpu][level_zero:0] Intel(R) oneAPI Unified Runtime over Level-Zero, Intel(R) Arc(TM) A770 Graphics 12.55.8 [1.6.32224.500000]
+[opencl:cpu][opencl:0] Intel(R) OpenCL, Intel(R) Xeon(R) w5-3435X OpenCL 3.0 (Build 0) [2024.18.12.0.05_160000]
+[opencl:gpu][opencl:1] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [24.52.32224.5]
 ```

 ### 1. Install
@@ -43,15 +46,16 @@ source /opt/intel/oneapi/setvars.sh
 conda create -n ipex-vllm python=3.11
 conda activate ipex-vllm
 # Install dependencies
-pip install --pre --upgrade "ipex-llm[xpu]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install --pre --upgrade "ipex-llm[xpu_2.6]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 pip install setuptools-scm
 pip install --upgrade cmake
 # cd to your workdir
-git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git
+git clone -b 0.6.6 https://github.com/analytics-zoo/vllm.git
 cd vllm
-VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v .
+VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm
 # For Qwen model support
 pip install transformers_stream_generator einops tiktoken
+pip install ray
 ```

 ### 2. Configure recommended environment variables
@@ -205,3 +209,35 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
   --distributed-executor-backend ray \
   --disable-async-output-proc
 ```
+
+### 4. Load low-bit models with vLLM
+
+To load a low-bit model directly with vLLM, use the `--low-bit-model-path` option when starting the service, or the `low_bit_model_path` option when using `vllm_offline_inference.py`.
+
+The low-bit model first needs to be saved using the `--low-bit-save-path` or `low_bit_save_path` option.
+
+For instance, to save a low-bit `DeepSeek-R1-Distill-Qwen-7B` model on disk, we can execute the following Python script.
+
+```python
+from vllm import SamplingParams
+from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
+
+# Create an LLM.
+llm = LLM(model="DeepSeek-R1-Distill-Qwen-7B",  # Unquantized model path on disk
+          device="xpu",
+          dtype="float16",
+          enforce_eager=True,
+          load_in_low_bit="sym_int4",           # The low-bit format to quantize to
+          tensor_parallel_size=1,               # The tp-size must be the same when you later use the low-bit model
+          disable_async_output_proc=True,
+          distributed_executor_backend="ray",
+          max_model_len=500,
+          trust_remote_code=True,
+          block_size=8,
+          max_num_batched_tokens=500,
+          low_bit_save_path="/llm/fp8-model-path")  # saved path
+```
+
+When the script finishes, the low-bit model has been saved at `/llm/fp8-model-path`.
+
+Later, we can pass `--low-bit-model-path /llm/fp8-model-path` to load the low-bit model.
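For the loading side, the sketch below mirrors the saving script above. It assumes that `IPEXLLMClass` accepts a `low_bit_model_path` keyword corresponding to the `--low-bit-model-path` flag and the `low_bit_model_path` option of `vllm_offline_inference.py` described in the new section; treat the exact parameter placement as illustrative rather than authoritative.

```python
from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

# Illustrative only: `low_bit_model_path` is assumed to be accepted as a keyword,
# mirroring the `--low-bit-model-path` server flag described above.
llm = LLM(model="DeepSeek-R1-Distill-Qwen-7B",      # original model path (config/tokenizer)
          device="xpu",
          dtype="float16",
          enforce_eager=True,
          load_in_low_bit="sym_int4",               # kept identical to the value used when saving (assumed)
          tensor_parallel_size=1,                   # must match the tp-size used when saving
          disable_async_output_proc=True,
          distributed_executor_backend="ray",
          max_model_len=500,
          trust_remote_code=True,
          block_size=8,
          max_num_batched_tokens=500,
          low_bit_model_path="/llm/fp8-model-path")  # directory produced by low_bit_save_path

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(["What is AI?"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```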
@@ -170,6 +170,8 @@ def load_low_bit(model, model_path):
     invalidInputError(isinstance(model, torch.nn.Module),
                       "model should be an instance of `torch.nn.Module`, "
                       f"but got {type(model)} at last.")
+    if hasattr(model, "device"):
+        # vLLM model objects may not expose a `device` attribute; only check it when present
         invalidInputError(model.device.type in ('cpu', 'meta'),
                           "Expect model on device `cpu` or `meta`, "
                           f"but got device type {model.device.type}")
@@ -117,7 +117,6 @@ class IPEXLLMClass(LLM):
         Note: if enforce_eager is unset (enforce_eager is None)
         it defaults to False.
         '''

         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-
@@ -82,19 +82,23 @@ def get_load_function(low_bit):
        # from vllm.utils import measure_device_memory
        from vllm.utils import DeviceMemoryProfiler
        with DeviceMemoryProfiler() as m:
+           import os
            from dataclasses import replace
            new_device_config = DeviceConfig("cpu")
            new_vllm_config = replace(self.vllm_config, device_config=new_device_config)
+           # If a pre-converted low-bit model is being loaded, all of the optimizations
+           # have already been applied, so the conversion steps below can be skipped.
            self.model = get_model(
                vllm_config=new_vllm_config
            )
+           if self.vllm_config.model_config.low_bit_model_path is None:
                if "qwen" in self.vllm_config.model_config.model.lower() or \
                        "baichuan" in self.vllm_config.model_config.model.lower() or \
                        "codegeex4-all" in self.vllm_config.model_config.model.lower() or \
                        "chatglm" in self.vllm_config.model_config.model.lower():
                    self.model.apply(padding_mlp)
                from ipex_llm import optimize_model
-               import os
                not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)
                if not_convert_last_mlp is not None:
                    # only used to avoid nan values in the last mlp forward when running glm4-9b-chat
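Taken together, the gate added above selects between two loading paths. The condensed sketch below restates that control flow outside of vLLM for readability; the helper name and simplified signature are illustrative, not part of the patch.

```python
def choose_loading_path(model, model_config, low_bit):
    """Condensed, illustrative restatement of the branching introduced above."""
    if model_config.low_bit_model_path is None:
        # No pre-quantized checkpoint was given: quantize the freshly loaded weights now.
        from ipex_llm import optimize_model
        model = optimize_model(model,
                               low_bit=low_bit,
                               torch_dtype=model_config.dtype)
    else:
        # A saved low-bit checkpoint will be loaded instead; the optimizations were
        # already applied when it was saved, so the conversion step is skipped here.
        pass
    return model
```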
@@ -111,6 +115,16 @@ def get_load_function(low_bit):
                        low_bit=low_bit,
                        torch_dtype=self.vllm_config.model_config.dtype,
                        modules_to_not_convert=modules)
+           # Guancheng: We have to save the model before moving it to the XPU device.
+           # The `to` method converts the underlying data, so saving first avoids
+           # converting the weights twice.
+           if self.vllm_config.model_config.low_bit_save_path is not None:
+               # The local_rank is used for loading models saved with tensor parallel settings.
+               local_rank = os.environ["LOCAL_RANK"]
+               saved_path = os.path.join(self.vllm_config.model_config.low_bit_save_path,
+                                         str(local_rank))
+               self.model.save_low_bit(saved_path)
+
            self.model = self.model.to(device=self.vllm_config.device_config.device,
                                       dtype=self.vllm_config.model_config.dtype)
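To make the tensor-parallel save layout concrete, the small sketch below reproduces the path construction from the patch; the rank count and paths are illustrative.

```python
import os

# Illustrative only: mirrors the os.path.join(low_bit_save_path, str(local_rank))
# logic above. Each tensor-parallel worker saves its shard under its LOCAL_RANK,
# so a tensor_parallel_size=2 run produces two sub-directories.
low_bit_save_path = "/llm/fp8-model-path"        # value passed as low_bit_save_path
for local_rank in ("0", "1"):                    # one entry per tensor-parallel worker
    print(os.path.join(low_bit_save_path, local_rank))
# /llm/fp8-model-path/0
# /llm/fp8-model-path/1
# When loading with --low-bit-model-path /llm/fp8-model-path, use the same
# tensor-parallel size so each rank finds its own shard.
```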