initial implementation for low_bit_loader vLLM (#12838)
* initial * add logic for handling tensor parallel models * fix * Add some comments * add doc * fix done
This commit is contained in:
		
							parent
							
								
									c81b7fc003
								
							
						
					
					
						commit
						4eed0c7d99
					
				
					 4 changed files with 88 additions and 37 deletions
				
			
		| 
						 | 
				
			
			@ -1,8 +1,8 @@
 | 
			
		|||
# vLLM continuous batching on Intel GPUs (experimental support)
 | 
			
		||||
# vLLM continuous batching on Intel GPUs
 | 
			
		||||
 | 
			
		||||
This example demonstrates how to serve a LLaMA2-7B model using vLLM continuous batching on Intel GPU (with IPEX-LLM low-bits optimizations).
 | 
			
		||||
 | 
			
		||||
The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.6.2).
 | 
			
		||||
The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.6.6).
 | 
			
		||||
 | 
			
		||||
Currently, we support the following models for vLLM engine:
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -10,6 +10,8 @@ Currently, we support the following models for vLLM engine:
 | 
			
		|||
- Llama series models
 | 
			
		||||
- ChatGLM series models
 | 
			
		||||
- Baichuan series models
 | 
			
		||||
- Deepseek series models
 | 
			
		||||
- Multimodal models
 | 
			
		||||
 | 
			
		||||
## Example: Serving LLaMA2-7B using Intel GPU
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -17,7 +19,9 @@ In this example, we will run Llama2-7b model using Arc A770 and provide `OpenAI-
 | 
			
		|||
 | 
			
		||||
### 0. Environment
 | 
			
		||||
 | 
			
		||||
To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.1. Please check the requirements at [here](https://www.intel.com/content/www/us/en/docs/oneapi/installation-guide-linux/2024-1/overview.html).
 | 
			
		||||
To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2025.0.1. Please check the requirements at [here](https://www.intel.com/content/www/us/en/docs/oneapi/installation-guide-linux/2025-0/overview.html).
 | 
			
		||||
 | 
			
		||||
Besides, you may also want to install the latest compute runtime at [here](https://github.com/intel/compute-runtime/releases)
 | 
			
		||||
 | 
			
		||||
After install the toolkit, run the following commands in your environment before starting vLLM GPU:
 | 
			
		||||
```bash
 | 
			
		||||
| 
						 | 
				
			
			@ -26,10 +30,9 @@ source /opt/intel/oneapi/setvars.sh
 | 
			
		|||
sycl-ls
 | 
			
		||||
 | 
			
		||||
# Example output with one Arc A770:
 | 
			
		||||
[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device 1.2 [2023.16.7.0.21_160000]
 | 
			
		||||
[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i9-13900K 3.0 [2023.16.7.0.21_160000]
 | 
			
		||||
[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics 3.0 [23.17.26241.33]
 | 
			
		||||
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26241]
 | 
			
		||||
[level_zero:gpu][level_zero:0] Intel(R) oneAPI Unified Runtime over Level-Zero, Intel(R) Arc(TM) A770 Graphics 12.55.8 [1.6.32224.500000]
 | 
			
		||||
[opencl:cpu][opencl:0] Intel(R) OpenCL, Intel(R) Xeon(R) w5-3435X OpenCL 3.0 (Build 0) [2024.18.12.0.05_160000]
 | 
			
		||||
[opencl:gpu][opencl:1] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO  [24.52.32224.5]
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 1. Install
 | 
			
		||||
| 
						 | 
				
			
			@ -43,15 +46,16 @@ source /opt/intel/oneapi/setvars.sh
 | 
			
		|||
conda create -n ipex-vllm python=3.11
 | 
			
		||||
conda activate ipex-vllm
 | 
			
		||||
# Install dependencies
 | 
			
		||||
pip install --pre --upgrade "ipex-llm[xpu]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 | 
			
		||||
pip install --pre --upgrade "ipex-llm[xpu_2.6]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 | 
			
		||||
pip install setuptools-scm
 | 
			
		||||
pip install --upgrade cmake
 | 
			
		||||
# cd to your workdir
 | 
			
		||||
git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git
 | 
			
		||||
git clone -b 0.6.6 https://github.com/analytics-zoo/vllm.git
 | 
			
		||||
cd vllm
 | 
			
		||||
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v .
 | 
			
		||||
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm
 | 
			
		||||
# For Qwen model support
 | 
			
		||||
pip install transformers_stream_generator einops tiktoken
 | 
			
		||||
pip install ray
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 2. Configure recommended environment variables
 | 
			
		||||
| 
						 | 
				
			
			@ -205,3 +209,35 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
 | 
			
		|||
  --distributed-executor-backend ray \
 | 
			
		||||
  --disable-async-output-proc
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 4. Load low bit models with vLLM
 | 
			
		||||
 | 
			
		||||
To load low-bit model directly with vLLM, we can use the following option `--low-bit-model-path` when starting service or `low_bit_model_path` when using `vllm_offline_inference.py`.
 | 
			
		||||
 | 
			
		||||
The low bit model needs to be saved using the `--low-bit-save-path` or `low_bit_save_path` option.
 | 
			
		||||
 | 
			
		||||
For instance, to save a FP8 low-bit `DeepSeek-R1-Distill-Qwen-7B` model on disk, we can execute the following python script.
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from vllm import SamplingParams
 | 
			
		||||
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
 | 
			
		||||
 | 
			
		||||
# Create an LLM.
 | 
			
		||||
llm = LLM(model="DeepSeek-R1-Distill-Qwen-7B", # Unquantized model path on disk
 | 
			
		||||
          device="xpu",
 | 
			
		||||
          dtype="float16",
 | 
			
		||||
          enforce_eager=True,
 | 
			
		||||
          load_in_low_bit="sym_int4",  # The low-bit you may want to quantized to
 | 
			
		||||
          tensor_parallel_size=1,      # The tp-size you choose needs to be same when you later uses the low-bit model
 | 
			
		||||
          disable_async_output_proc=True,
 | 
			
		||||
          distributed_executor_backend="ray",
 | 
			
		||||
          max_model_len=500,
 | 
			
		||||
          trust_remote_code=True,
 | 
			
		||||
          block_size=8,
 | 
			
		||||
          max_num_batched_tokens=500,
 | 
			
		||||
          low_bit_save_path="/llm/fp8-model-path")  # saved path
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
When finish executing, the low-bit model has been saved at `/llm/fp8-model-path`.
 | 
			
		||||
 | 
			
		||||
Later we can use the option `--low-bit-model-path /llm/fp8-model-path` to use the low-bit model.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -170,6 +170,8 @@ def load_low_bit(model, model_path):
 | 
			
		|||
        invalidInputError(isinstance(model, torch.nn.Module),
 | 
			
		||||
                          "model should be an instance of `torch.nn.Module`, "
 | 
			
		||||
                          f"but got {type(model)} at last.")
 | 
			
		||||
        if hasattr(model, "device"):
 | 
			
		||||
            # vLLM do not have device for model
 | 
			
		||||
            invalidInputError(model.device.type in ('cpu', 'meta'),
 | 
			
		||||
                              "Expect model on device `cpu` or `meta`, "
 | 
			
		||||
                              f"but got device type {model.device.type}")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -117,7 +117,6 @@ class IPEXLLMClass(LLM):
 | 
			
		|||
        Note: if enforce_eager is unset (enforce_eager is None)
 | 
			
		||||
        it defaults to False.
 | 
			
		||||
        '''
 | 
			
		||||
 | 
			
		||||
        if "disable_log_stats" not in kwargs:
 | 
			
		||||
            kwargs["disable_log_stats"] = True
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -82,19 +82,23 @@ def get_load_function(low_bit):
 | 
			
		|||
        # from vllm.utils import measure_device_memory
 | 
			
		||||
        from vllm.utils import DeviceMemoryProfiler
 | 
			
		||||
        with DeviceMemoryProfiler() as m:
 | 
			
		||||
            import os
 | 
			
		||||
            from dataclasses import replace
 | 
			
		||||
            new_device_config = DeviceConfig("cpu")
 | 
			
		||||
            new_vllm_config = replace(self.vllm_config, device_config=new_device_config)
 | 
			
		||||
            # We are loading an low-bit model, where all the optimizations should have been
 | 
			
		||||
            # applied...
 | 
			
		||||
            # We can skip the following optimizations
 | 
			
		||||
            self.model = get_model(
 | 
			
		||||
                vllm_config=new_vllm_config
 | 
			
		||||
            )
 | 
			
		||||
            if self.vllm_config.model_config.low_bit_model_path is None:
 | 
			
		||||
                if "qwen" in self.vllm_config.model_config.model.lower() or \
 | 
			
		||||
                        "baichuan" in self.vllm_config.model_config.model.lower() or \
 | 
			
		||||
                        "codegeex4-all" in self.vllm_config.model_config.model.lower() or \
 | 
			
		||||
                        "chatglm" in self.vllm_config.model_config.model.lower():
 | 
			
		||||
                    self.model.apply(padding_mlp)
 | 
			
		||||
                from ipex_llm import optimize_model
 | 
			
		||||
            import os
 | 
			
		||||
                not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)
 | 
			
		||||
                if not_convert_last_mlp is not None:
 | 
			
		||||
                    # only use to avoid nan value in last mlp forward running glm4-9b-chat
 | 
			
		||||
| 
						 | 
				
			
			@ -111,6 +115,16 @@ def get_load_function(low_bit):
 | 
			
		|||
                               low_bit=low_bit,
 | 
			
		||||
                               torch_dtype=self.vllm_config.model_config.dtype,
 | 
			
		||||
                               modules_to_not_convert=modules)
 | 
			
		||||
            # Guancheng: We have to save the model before moving it to the XPU device.
 | 
			
		||||
            # The `to` method will convert the underlying data.
 | 
			
		||||
            # Saving it before will help to avoid converting two times.
 | 
			
		||||
            if self.vllm_config.model_config.low_bit_save_path is not None:
 | 
			
		||||
                # The local_rank is used for loading models with tensor parallel settings.
 | 
			
		||||
                local_rank = os.environ["LOCAL_RANK"]
 | 
			
		||||
                saved_path = os.path.join(self.vllm_config.model_config.low_bit_save_path,
 | 
			
		||||
                                          str(local_rank))
 | 
			
		||||
                self.model.save_low_bit(saved_path)
 | 
			
		||||
 | 
			
		||||
            self.model = self.model.to(device=self.vllm_config.device_config.device,
 | 
			
		||||
                                       dtype=self.vllm_config.model_config.dtype)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue