From ac3d53ff5ddba884bbe2a65a888b3d0965ba5b7c Mon Sep 17 00:00:00 2001
From: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com>
Date: Tue, 4 Jun 2024 19:10:23 +0800
Subject: [PATCH] LLM: Fix vLLM CPU version error (#11206)

Fix vLLM CPU version error
---
 .../serving/cpu/docker/start-vllm-service.sh |  2 +-
 .../cpu/docker/vllm_offline_inference.py     |  2 +-
 .../CPU/vLLM-Serving/offline_inference.py    |  2 +-
 .../llm/src/ipex_llm/transformers/convert.py | 29 +++++--
 .../src/ipex_llm/vllm/cpu/model_convert.py   | 76 ++++++++++++++++---
 5 files changed, 94 insertions(+), 17 deletions(-)

diff --git a/docker/llm/serving/cpu/docker/start-vllm-service.sh b/docker/llm/serving/cpu/docker/start-vllm-service.sh
index 4f23adfa..b8e442da 100644
--- a/docker/llm/serving/cpu/docker/start-vllm-service.sh
+++ b/docker/llm/serving/cpu/docker/start-vllm-service.sh
@@ -11,7 +11,7 @@ python -m ipex_llm.vllm.cpu.entrypoints.openai.api_server \
   --device cpu \
   --dtype bfloat16 \
   --enforce-eager \
-  --load-in-low-bit sym_int4 \
+  --load-in-low-bit bf16 \
   --max-model-len 4096 \
   --max-num-batched-tokens 10240 \
   --max-num-seqs 12 \
diff --git a/docker/llm/serving/cpu/docker/vllm_offline_inference.py b/docker/llm/serving/cpu/docker/vllm_offline_inference.py
index 02829a9e..23cee000 100644
--- a/docker/llm/serving/cpu/docker/vllm_offline_inference.py
+++ b/docker/llm/serving/cpu/docker/vllm_offline_inference.py
@@ -49,7 +49,7 @@ llm = LLM(model="YOUR_MODEL",
           device="cpu",
           dtype="bfloat16",
           enforce_eager=True,
-          load_in_low_bit="sym_int4",
+          load_in_low_bit="bf16",
           tensor_parallel_size=1)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
diff --git a/python/llm/example/CPU/vLLM-Serving/offline_inference.py b/python/llm/example/CPU/vLLM-Serving/offline_inference.py
index 142b80d4..63b74f7d 100644
--- a/python/llm/example/CPU/vLLM-Serving/offline_inference.py
+++ b/python/llm/example/CPU/vLLM-Serving/offline_inference.py
@@ -46,7 +46,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
 # Create an LLM.
 # llm = LLM(model="facebook/opt-125m")
-llm = LLM(model="YOUR_MODEL_PATH", device="cpu", load_in_low_bit="sym_int4")
+llm = LLM(model="YOUR_MODEL_PATH", device="cpu", load_in_low_bit="bf16")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 090d534b..6ab8fa7a 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -54,6 +54,7 @@ import sys
 
 _IS_VLLM_AVAILABLE = None
 _USE_VLLM = False
+_VLLM_VERSION = None
 
 
 def is_auto_gptq_available():
@@ -77,6 +78,14 @@ def is_vllm_available():
     return _IS_VLLM_AVAILABLE
 
 
+def get_package_version(package_name):
+    result = subprocess.run(['pip', 'list'], capture_output=True, text=True)
+    for line in result.stdout.splitlines():
+        if line.startswith(package_name):
+            return line.split()[1]
+    return None
+
+
 def get_use_vllm():
     return _USE_VLLM
 
@@ -133,13 +142,24 @@ def is_linear_module(module):
     is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
     if is_vllm_available():
         # Only convert vllm modules
+        global _VLLM_VERSION
+        if _VLLM_VERSION is None:
+            _VLLM_VERSION = get_package_version('vllm')
+        if 'xpu' in _VLLM_VERSION:
+            # For vllm xpu
+            from vllm.model_executor.parallel_utils.parallel_state import (
+                get_tensor_model_parallel_group,
+                get_tensor_model_parallel_world_size
+            )
+            tp_size = get_tensor_model_parallel_world_size()
+        else:
+            # For vllm cpu
+            tp_size = 1
+
         from vllm.model_executor.layers.linear import (
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         )
-        from vllm.model_executor.parallel_utils.parallel_state import (
-            get_tensor_model_parallel_group,
-            get_tensor_model_parallel_world_size
-        )
+
         VLLM_LINEAR_LIST = [
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         ]
@@ -148,7 +168,6 @@ def is_linear_module(module):
             out_features = module.output_size
             result = True
             mp_group = None
-            tp_size = get_tensor_model_parallel_world_size()
             if isinstance(module, RowParallelLinear) and tp_size >= 2:
                 mp_group = get_tensor_model_parallel_group()
                 in_features = module.input_size_per_partition
diff --git a/python/llm/src/ipex_llm/vllm/cpu/model_convert.py b/python/llm/src/ipex_llm/vllm/cpu/model_convert.py
index 4228eb06..cf34a63b 100644
--- a/python/llm/src/ipex_llm/vllm/cpu/model_convert.py
+++ b/python/llm/src/ipex_llm/vllm/cpu/model_convert.py
@@ -24,8 +24,10 @@ from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
 from vllm.attention import Attention, AttentionMetadata
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.config import DeviceConfig
-from typing import Tuple
+
+from vllm._C import ops
 from ipex_llm.utils.common import invalidInputError
+from typing import List, Optional, Tuple, Union
 
 
 def _MLP_forward(self, x):
@@ -42,7 +44,7 @@ def _Attention_forward(
     kv_cache: torch.Tensor,
     attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
-    qkv = self.qkv_proj(hidden_states)
+    qkv = self.qkv_proj(hidden_states).to(dtype=kv_cache.dtype)
     q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
     q, k = self.rotary_emb(positions, q, k)
     attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
@@ -145,9 +147,61 @@ def _model_attention_convert():
 
 
 def _ipex_llm_convert(load_in_low_bit):
-    from vllm.worker.model_runner import ModelRunner
+    from vllm.worker.cpu_model_runner import CPUModelRunner
     import vllm.model_executor.model_loader as model_loader
-    setattr(ModelRunner, "load_model", get_load_function(load_in_low_bit))
+    setattr(CPUModelRunner, "load_model", get_load_function(load_in_low_bit))
+
+    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+    setattr(RotaryEmbedding, "forward", _ipex_llm_rotary_embedding_forward)
+    from vllm.model_executor.layers.layernorm import RMSNorm
+    setattr(RMSNorm, "forward", _ipex_llm_rmsnorm_forward)
+
+
+def _ipex_llm_rotary_embedding_forward(
+    self,
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
+
+    # ops.rotary_embedding()/batched_rotary_embedding()
+    # are in-place operations that update the query and key tensors.
+    if offsets is not None:
+        ops.batched_rotary_embedding(positions, query, key, self.head_size,
+                                     self.cos_sin_cache,
+                                     self.is_neox_style, self.rotary_dim,
+                                     offsets)
+    else:
+        ops.rotary_embedding(positions, query, key, self.head_size,
+                             self.cos_sin_cache, self.is_neox_style)
+    return query, key
+
+
+def _ipex_llm_rmsnorm_forward(
+    self,
+    x: torch.Tensor,
+    residual: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    x = x.to(dtype=self.weight.data.dtype)
+    if residual is not None:
+        residual = residual.to(dtype=self.weight.data.dtype)
+        ops.fused_add_rms_norm(
+            x,
+            residual,
+            self.weight.data,
+            self.variance_epsilon,
+        )
+        return x, residual
+    out = torch.empty_like(x)
+    ops.rms_norm(
+        out,
+        x,
+        self.weight.data,
+        self.variance_epsilon,
+    )
+    return out
 
 
 def get_load_function(low_bit):
@@ -155,11 +209,15 @@ def get_load_function(low_bit):
         _model_mlp_convert()
         _model_attention_convert()
 
-        self.model = get_model(self.model_config,
-                               self.device_config,
-                               lora_config=self.lora_config,
-                               parallel_config=self.parallel_config,
-                               scheduler_config=self.scheduler_config)
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=self.device_config,
+            vision_language_config=self.vision_language_config,
+            lora_config=self.lora_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config)
+
         from ipex_llm import optimize_model
         optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype)
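
Note (not part of the patch): the CPU examples above all switch load_in_low_bit from "sym_int4" to "bf16". A minimal end-to-end sketch of that offline-inference path follows. The LLM keyword arguments and the SamplingParams values are copied from the patched example files; the prompts are illustrative, and the ipex_llm.vllm.cpu.engine.IPEXLLMClass import is an assumption about how ipex-llm exposes its wrapped vLLM LLM class, so adjust it to your installed version.

# Sketch only: mirrors the patched CPU offline-inference examples.
# Assumption: ipex-llm exposes the wrapped vLLM engine as IPEXLLMClass under
# ipex_llm.vllm.cpu.engine; "YOUR_MODEL_PATH" is a placeholder, as in the examples.
from vllm import SamplingParams
from ipex_llm.vllm.cpu.engine import IPEXLLMClass as LLM

# Illustrative prompts; any list of strings works.
prompts = [
    "Hello, my name is",
    "The capital of France is",
]
# Sampling settings taken from the patched offline_inference.py example.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Keyword arguments taken from the patched vllm_offline_inference.py example;
# load_in_low_bit="bf16" is the value this patch switches the CPU examples to.
llm = LLM(model="YOUR_MODEL_PATH",
          device="cpu",
          dtype="bfloat16",
          enforce_eager=True,
          load_in_low_bit="bf16",
          tensor_parallel_size=1)

# Each RequestOutput carries the prompt, the generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")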