LLM: Fix vLLM CPU version error (#11206)

Author: Xiangyu Tian, 2024-06-04 19:10:23 +08:00, committed by GitHub
parent 3ef4aa98d1
commit ac3d53ff5d
5 changed files with 94 additions and 17 deletions


@@ -11,7 +11,7 @@ python -m ipex_llm.vllm.cpu.entrypoints.openai.api_server \
--device cpu \
--dtype bfloat16 \
--enforce-eager \
--load-in-low-bit sym_int4 \
--load-in-low-bit bf16 \
--max-model-len 4096 \
--max-num-batched-tokens 10240 \
--max-num-seqs 12 \
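
Once the server above is running, it can be exercised with any OpenAI-compatible client. A minimal sketch follows; the host, port, prompt, and sampling values are assumptions for illustration, not values from this commit, and "YOUR_MODEL" must match whatever the server was launched with:

# Hedged sketch: query the OpenAI-compatible completions endpoint served by the command above.
# Host, port, prompt, and sampling values are illustrative assumptions.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "YOUR_MODEL",           # must match the model passed to the server
        "prompt": "San Francisco is a",
        "max_tokens": 64,
        "temperature": 0.8,
    },
)
print(resp.json()["choices"][0]["text"])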


@@ -49,7 +49,7 @@ llm = LLM(model="YOUR_MODEL",
device="cpu",
dtype="bfloat16",
enforce_eager=True,
load_in_low_bit="sym_int4",
load_in_low_bit="bf16",
tensor_parallel_size=1)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
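
A short usage sketch for the llm object constructed above; the prompts and sampling settings are placeholders, not part of this commit:

# Hedged sketch: drive the LLM built above with a couple of placeholder prompts.
from vllm import SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

outputs = llm.generate(prompts, sampling_params)   # list of RequestOutput objects
for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)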


@@ -46,7 +46,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
# llm = LLM(model="facebook/opt-125m")
llm = LLM(model="YOUR_MODEL_PATH", device="cpu", load_in_low_bit="sym_int4")
llm = LLM(model="YOUR_MODEL_PATH", device="cpu", load_in_low_bit="bf16")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
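
The RequestOutput objects returned above can be unpacked as sketched below; field names follow vLLM's public API, and the print formatting is illustrative only:

# Hedged sketch: inspect each RequestOutput returned by llm.generate() above.
for output in outputs:
    prompt = output.prompt                    # the input prompt
    generated_text = output.outputs[0].text   # first completion for that prompt
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")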


@@ -54,6 +54,7 @@ import sys
_IS_VLLM_AVAILABLE = None
_USE_VLLM = False
_VLLM_VERSION = None
def is_auto_gptq_available():
@@ -77,6 +78,14 @@ def is_vllm_available():
return _IS_VLLM_AVAILABLE
def get_package_version(package_name):
result = subprocess.run(['pip', 'list'], capture_output=True, text=True)
for line in result.stdout.splitlines():
if line.startswith(package_name):
return line.split()[1]
return None
def get_use_vllm():
return _USE_VLLM
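
The get_package_version helper added above shells out to pip list and returns the version string of the named package, or None if it is not installed. A hedged sketch of how that value feeds the xpu/cpu branch later in this commit; the version strings shown are hypothetical examples:

# Hedged sketch: typical inspection of get_package_version()'s return value.
# The example version strings are hypothetical, not taken from this commit.
version = get_package_version('vllm')   # e.g. "0.4.2+cpu", "0.4.2+xpu", or None
if version is not None and 'xpu' in version:
    print("vLLM XPU build detected: query the tensor-parallel world size")
else:
    print("assuming vLLM CPU build: tensor parallel size is fixed to 1")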
@@ -133,13 +142,24 @@ def is_linear_module(module):
is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
if is_vllm_available():
# Only convert vllm modules
from vllm.model_executor.layers.linear import (
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
)
global _VLLM_VERSION
if _VLLM_VERSION is None:
_VLLM_VERSION = get_package_version('vllm')
if 'xpu' in _VLLM_VERSION:
# For vllm xpu
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size
)
tp_size = get_tensor_model_parallel_world_size()
else:
# For vllm cpu
tp_size = 1
from vllm.model_executor.layers.linear import (
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
)
VLLM_LINEAR_LIST = [
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
]
@@ -148,7 +168,6 @@ def is_linear_module(module):
out_features = module.output_size
result = True
mp_group = None
tp_size = get_tensor_model_parallel_world_size()
if isinstance(module, RowParallelLinear) and tp_size >= 2:
mp_group = get_tensor_model_parallel_group()
in_features = module.input_size_per_partition
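
Taken together, the two hunks above query vLLM's tensor-parallel state only for the XPU build and pin tp_size to 1 on CPU, where that state is not initialized. A condensed, hedged restatement of that control flow; _resolve_tp_size is a hypothetical helper introduced here for illustration and does not exist in the codebase:

# Hedged, condensed restatement of the branch added above; _resolve_tp_size is hypothetical.
def _resolve_tp_size(vllm_version: str) -> int:
    if 'xpu' in vllm_version:
        # XPU build: ask vLLM's parallel state for the actual world size.
        from vllm.model_executor.parallel_utils.parallel_state import (
            get_tensor_model_parallel_world_size,
        )
        return get_tensor_model_parallel_world_size()
    # CPU build: no tensor parallelism, so a single partition.
    return 1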


@@ -24,8 +24,10 @@ from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
from vllm.attention import Attention, AttentionMetadata
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.config import DeviceConfig
from typing import Tuple
from vllm._C import ops
from ipex_llm.utils.common import invalidInputError
from typing import List, Optional, Tuple, Union
def _MLP_forward(self, x):
@@ -42,7 +44,7 @@ def _Attention_forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
qkv = self.qkv_proj(hidden_states).to(dtype=kv_cache.dtype)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
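
The cast added to _Attention_forward aligns the projected qkv tensor with the KV cache dtype before the split into q, k, and v. A self-contained illustration of that alignment; the shapes and dtypes below are assumptions chosen for the example:

# Hedged illustration of the dtype alignment added above; shapes and dtypes are assumptions.
import torch

kv_cache = torch.zeros(4, 8, dtype=torch.bfloat16)   # stand-in for a bf16 KV cache block
qkv = torch.randn(2, 24, dtype=torch.float32)        # stand-in for the qkv projection output
qkv = qkv.to(dtype=kv_cache.dtype)                    # match the cache dtype before splitting
q, k, v = qkv.split([8, 8, 8], dim=-1)
print(q.dtype, k.dtype, v.dtype)                      # all torch.bfloat16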
@@ -145,9 +147,61 @@ def _model_attention_convert():
def _ipex_llm_convert(load_in_low_bit):
from vllm.worker.model_runner import ModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner
import vllm.model_executor.model_loader as model_loader
setattr(ModelRunner, "load_model", get_load_function(load_in_low_bit))
setattr(CPUModelRunner, "load_model", get_load_function(load_in_low_bit))
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
setattr(RotaryEmbedding, "forward", _ipex_llm_rotary_embedding_forward)
from vllm.model_executor.layers.layernorm import RMSNorm
setattr(RMSNorm, "forward", _ipex_llm_rmsnorm_forward)
def _ipex_llm_rotary_embedding_forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
ops.batched_rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache,
self.is_neox_style, self.rotary_dim,
offsets)
else:
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
return query, key
def _ipex_llm_rmsnorm_forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
x = x.to(dtype=self.weight.data.dtype)
if residual is not None:
residual = residual.to(dtype=self.weight.data.dtype)
ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
ops.rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out
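
The patched _ipex_llm_rmsnorm_forward above casts its inputs to the weight dtype and then dispatches to vLLM's fused kernels. A pure-PyTorch reference for the math ops.rms_norm performs, useful as a sanity check of that cast; this is a sketch of the standard RMSNorm formula, not the kernel itself:

# Hedged pure-PyTorch reference for the RMSNorm math behind ops.rms_norm; not the fused kernel.
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    x = x.to(dtype=weight.dtype)                      # same cast the patched forward applies
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight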
def get_load_function(low_bit):
@@ -155,11 +209,15 @@ def get_load_function(low_bit):
_model_mlp_convert()
_model_attention_convert()
self.model = get_model(self.model_config,
self.device_config,
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
vision_language_config=self.vision_language_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
from ipex_llm import optimize_model
optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype)
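
get_load_function returns a replacement load_model that _ipex_llm_convert installs on ModelRunner and CPUModelRunner via setattr. A generic, hedged sketch of that closure-plus-setattr pattern; DummyRunner and make_load_model are made up for illustration and are not ipex-llm or vLLM API:

# Hedged, generic sketch of the setattr + closure pattern used by _ipex_llm_convert above.
# DummyRunner and make_load_model are illustrative stand-ins, not ipex-llm or vLLM API.
def make_load_model(low_bit):
    def load_model(self):
        # ... build the model here, then apply the low-bit optimization ...
        print(f"patched load_model called with low_bit={low_bit}")
    return load_model

class DummyRunner:
    pass

setattr(DummyRunner, "load_model", make_load_model("bf16"))
DummyRunner().load_model()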