LLM: Fix vLLM CPU model convert mismatch (#11254)
Fix vLLM CPU model convert mismatch.
This commit is contained in:
parent
42fab480ea
commit
4b07712fd8
3 changed files with 60 additions and 28 deletions
|
|
@ -37,7 +37,7 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
|
|||
engine_args: AsyncEngineArgs,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
load_in_low_bit: Optional[str] = None,
|
||||
) -> "AsyncLLMEngine":
|
||||
"""Creates an async LLM engine from the engine arguments."""
|
||||
# Enable ipex-llm optimizations
|
||||
|
|
@ -97,7 +97,7 @@ class IPEXLLMClass(LLM):
|
|||
max_context_len_to_capture: Optional[int] = None,
|
||||
max_seq_len_to_capture: int = 8192,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
load_in_low_bit: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if "disable_log_stats" not in kwargs:
|
||||
|
|
@ -136,8 +136,7 @@ class IPEXLLMLLMEngine(LLMEngine):
|
|||
cls,
|
||||
engine_args: EngineArgs,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
# ipex_llm_optimize_mode: str = 'NATIVE',
|
||||
load_in_low_bit: Optional[str] = None,
|
||||
) -> "LLMEngine":
|
||||
"""Creates an LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
"--load-in-low-bit",
|
||||
type=str,
|
||||
default="sym_int4",
|
||||
default=None,
|
||||
help="Low-bit quantization for IPEX-LLM models")
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
import torch
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.model_loader.utils import get_model_architecture
|
||||
from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention
|
||||
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
|
||||
from vllm.model_executor.models.qwen import QWenMLP, QWenAttention
|
||||
|
|
@ -24,11 +25,14 @@ from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
|
|||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||
from vllm.config import DeviceConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm._C import ops
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _MLP_forward(self, x):
|
||||
gate_up = self.gate_up_proj(x)
|
||||
|
|
@ -59,10 +63,10 @@ def _QWen_Attention_forward(
|
|||
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv = self.c_attn(hidden_states)
|
||||
qkv = self.c_attn(hidden_states).to(dtype=kv_cache.dtype)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output = self.c_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
|
@ -74,6 +78,21 @@ def _QWen_MLP_forward(self, x):
|
|||
return x
|
||||
|
||||
|
||||
def _Qwen2_Attention_forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv = self.qkv_proj(hidden_states).to(dtype=kv_cache.dtype)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
def _ChatGLM_MLP_forward(self, hidden_states):
|
||||
# [s, b, 4hp]
|
||||
intermediate_parallel = self.dense_h_to_4h(hidden_states)
|
||||
|
|
@ -90,11 +109,11 @@ def _Baichuan_Attention_forward(
|
|||
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv = self.W_pack(hidden_states)
|
||||
qkv = self.W_pack(hidden_states).to(dtype=kv_cache.dtype)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
if self.postion_embedding != "ALIBI":
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
|
@ -106,7 +125,7 @@ def _ChatGLM_Attention_forward(
|
|||
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = self.query_key_value(hidden_states).to(dtype=kv_cache.dtype)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
context_layer = self.attn(
|
||||
|
|
@ -123,18 +142,25 @@ _REPLACED_MLP_LAYERS = {
|
|||
LlamaMLP: _MLP_forward,
|
||||
Qwen2MLP: _MLP_forward,
|
||||
BaiChuanMLP: _MLP_forward,
|
||||
QWenMLP: _QWen_MLP_forward,
|
||||
# QWenMLP: _QWen_MLP_forward,
|
||||
GLMMLP: _ChatGLM_MLP_forward
|
||||
}
|
||||
|
||||
_REPLACED_ATTENTION_LAYERS = {
|
||||
LlamaAttention: _Attention_forward,
|
||||
Qwen2Attention: _Attention_forward,
|
||||
QWenAttention: _QWen_Attention_forward,
|
||||
Qwen2Attention: _Qwen2_Attention_forward,
|
||||
# QWenAttention: _QWen_Attention_forward,
|
||||
BaiChuanAttention: _Baichuan_Attention_forward,
|
||||
GLMAttention: _ChatGLM_Attention_forward
|
||||
}
|
||||
|
||||
_IPEX_LLM_SUPPORTED_MODELS = [
|
||||
"LlamaForCausalLM",
|
||||
"BaichuanForCausalLM",
|
||||
"ChatGLMForCausalLM",
|
||||
"Qwen2ForCausalLM",
|
||||
]
|
||||
|
||||
|
||||
def _model_mlp_convert():
|
||||
for module, replaced_func in _REPLACED_MLP_LAYERS.items():
|
||||
|
|
@ -147,6 +173,8 @@ def _model_attention_convert():
|
|||
|
||||
|
||||
def _ipex_llm_convert(load_in_low_bit):
|
||||
if load_in_low_bit is None:
|
||||
return
|
||||
from vllm.worker.cpu_model_runner import CPUModelRunner
|
||||
import vllm.model_executor.model_loader as model_loader
|
||||
setattr(CPUModelRunner, "load_model", get_load_function(load_in_low_bit))
|
||||
|
|
@ -206,6 +234,26 @@ def _ipex_llm_rmsnorm_forward(
|
|||
|
||||
def get_load_function(low_bit):
|
||||
def _ipex_llm_load_model(self) -> None:
|
||||
model_class = get_model_architecture(self.model_config)[1]
|
||||
cur_model_list = ", ".join(_IPEX_LLM_SUPPORTED_MODELS)
|
||||
if low_bit != "bf16":
|
||||
invalidInputError(model_class in _IPEX_LLM_SUPPORTED_MODELS,
|
||||
f"Currently IPEX-LLM vLLM convert only support {cur_model_list}.")
|
||||
else:
|
||||
if model_class not in _IPEX_LLM_SUPPORTED_MODELS:
|
||||
logger.warning(
|
||||
f"Currently IPEX-LLM vLLM convert only support {cur_model_list}."
|
||||
)
|
||||
self.model = get_model(
|
||||
model_config=self.model_config,
|
||||
load_config=self.load_config,
|
||||
device_config=self.device_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
lora_config=self.lora_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config)
|
||||
return
|
||||
|
||||
_model_mlp_convert()
|
||||
_model_attention_convert()
|
||||
|
||||
|
|
@ -221,19 +269,4 @@ def get_load_function(low_bit):
|
|||
from ipex_llm import optimize_model
|
||||
optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype)
|
||||
|
||||
if self.lora_config:
|
||||
invalidInputError(hasattr(self.model, "supported_lora_modules")
|
||||
and self.model.supported_lora_modules,
|
||||
"Model does not support LoRA")
|
||||
invalidInputError(hasattr(self.model, "embedding_modules"),
|
||||
"Model does not have embedding_modules")
|
||||
invalidInputError(hasattr(self.model, "embedding_padding_modules"),
|
||||
"Model does not have embedding_padding_modules")
|
||||
self.lora_manager = LRUCacheWorkerLoRAManager(
|
||||
self.scheduler_config.max_num_seqs,
|
||||
self.scheduler_config.max_num_batched_tokens +
|
||||
self.scheduler_config.max_paddings, self.vocab_size,
|
||||
self.lora_config, self.device, self.model.embedding_modules,
|
||||
self.model.embedding_padding_modules)
|
||||
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||
return _ipex_llm_load_model
|
||||
|
|
|
|||
Loading…
Reference in a new issue