parent
3ef4aa98d1
commit
ac3d53ff5d
5 changed files with 94 additions and 17 deletions
|
|
@ -11,7 +11,7 @@ python -m ipex_llm.vllm.cpu.entrypoints.openai.api_server \
|
|||
--device cpu \
|
||||
--dtype bfloat16 \
|
||||
--enforce-eager \
|
||||
--load-in-low-bit sym_int4 \
|
||||
--load-in-low-bit bf16 \
|
||||
--max-model-len 4096 \
|
||||
--max-num-batched-tokens 10240 \
|
||||
--max-num-seqs 12 \
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ llm = LLM(model="YOUR_MODEL",
|
|||
device="cpu",
|
||||
dtype="bfloat16",
|
||||
enforce_eager=True,
|
||||
load_in_low_bit="sym_int4",
|
||||
load_in_low_bit="bf16",
|
||||
tensor_parallel_size=1)
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
|||
|
||||
# Create an LLM.
|
||||
# llm = LLM(model="facebook/opt-125m")
|
||||
llm = LLM(model="YOUR_MODEL_PATH", device="cpu", load_in_low_bit="sym_int4")
|
||||
llm = LLM(model="YOUR_MODEL_PATH", device="cpu", load_in_low_bit="bf16")
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ import sys
|
|||
|
||||
_IS_VLLM_AVAILABLE = None
|
||||
_USE_VLLM = False
|
||||
_VLLM_VERSION = None
|
||||
|
||||
|
||||
def is_auto_gptq_available():
|
||||
|
|
@ -77,6 +78,14 @@ def is_vllm_available():
|
|||
return _IS_VLLM_AVAILABLE
|
||||
|
||||
|
||||
def get_package_version(package_name):
|
||||
result = subprocess.run(['pip', 'list'], capture_output=True, text=True)
|
||||
for line in result.stdout.splitlines():
|
||||
if line.startswith(package_name):
|
||||
return line.split()[1]
|
||||
return None
|
||||
|
||||
|
||||
def get_use_vllm():
|
||||
return _USE_VLLM
|
||||
|
||||
|
|
@ -133,13 +142,24 @@ def is_linear_module(module):
|
|||
is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
|
||||
if is_vllm_available():
|
||||
# Only convert vllm modules
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
|
||||
)
|
||||
global _VLLM_VERSION
|
||||
if _VLLM_VERSION is None:
|
||||
_VLLM_VERSION = get_package_version('vllm')
|
||||
if 'xpu' in _VLLM_VERSION:
|
||||
# For vllm xpu
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_group,
|
||||
get_tensor_model_parallel_world_size
|
||||
)
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
else:
|
||||
# For vllm cpu
|
||||
tp_size = 1
|
||||
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
|
||||
)
|
||||
|
||||
VLLM_LINEAR_LIST = [
|
||||
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
|
||||
]
|
||||
|
|
@ -148,7 +168,6 @@ def is_linear_module(module):
|
|||
out_features = module.output_size
|
||||
result = True
|
||||
mp_group = None
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if isinstance(module, RowParallelLinear) and tp_size >= 2:
|
||||
mp_group = get_tensor_model_parallel_group()
|
||||
in_features = module.input_size_per_partition
|
||||
|
|
|
|||
|
|
@ -24,8 +24,10 @@ from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
|
|||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||
from vllm.config import DeviceConfig
|
||||
from typing import Tuple
|
||||
|
||||
from vllm._C import ops
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
||||
def _MLP_forward(self, x):
|
||||
|
|
@ -42,7 +44,7 @@ def _Attention_forward(
|
|||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
qkv = self.qkv_proj(hidden_states).to(dtype=kv_cache.dtype)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
|
||||
|
|
@ -145,9 +147,61 @@ def _model_attention_convert():
|
|||
|
||||
|
||||
def _ipex_llm_convert(load_in_low_bit):
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
from vllm.worker.cpu_model_runner import CPUModelRunner
|
||||
import vllm.model_executor.model_loader as model_loader
|
||||
setattr(ModelRunner, "load_model", get_load_function(load_in_low_bit))
|
||||
setattr(CPUModelRunner, "load_model", get_load_function(load_in_low_bit))
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
setattr(RotaryEmbedding, "forward", _ipex_llm_rotary_embedding_forward)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
setattr(RMSNorm, "forward", _ipex_llm_rmsnorm_forward)
|
||||
|
||||
|
||||
def _ipex_llm_rotary_embedding_forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
|
||||
|
||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||
# are in-place operations that update the query and key tensors.
|
||||
if offsets is not None:
|
||||
ops.batched_rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style, self.rotary_dim,
|
||||
offsets)
|
||||
else:
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
|
||||
def _ipex_llm_rmsnorm_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
x = x.to(dtype=self.weight.data.dtype)
|
||||
if residual is not None:
|
||||
residual = residual.to(dtype=self.weight.data.dtype)
|
||||
ops.fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
out = torch.empty_like(x)
|
||||
ops.rms_norm(
|
||||
out,
|
||||
x,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def get_load_function(low_bit):
|
||||
|
|
@ -155,11 +209,15 @@ def get_load_function(low_bit):
|
|||
_model_mlp_convert()
|
||||
_model_attention_convert()
|
||||
|
||||
self.model = get_model(self.model_config,
|
||||
self.device_config,
|
||||
self.model = get_model(
|
||||
model_config=self.model_config,
|
||||
load_config=self.load_config,
|
||||
device_config=self.device_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
lora_config=self.lora_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config)
|
||||
|
||||
from ipex_llm import optimize_model
|
||||
optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue