LLM: Fix vLLM CPU model convert mismatch (#11254)
Fix vLLM CPU model convert mismatch.
parent 42fab480ea
commit 4b07712fd8

3 changed files with 60 additions and 28 deletions
@@ -37,7 +37,7 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
         engine_args: AsyncEngineArgs,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        load_in_low_bit: str = "sym_int4",
+        load_in_low_bit: Optional[str] = None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Enable ipex-llm optimizations
@@ -97,7 +97,7 @@ class IPEXLLMClass(LLM):
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
-        load_in_low_bit: str = "sym_int4",
+        load_in_low_bit: Optional[str] = None,
         **kwargs,
     ) -> None:
         if "disable_log_stats" not in kwargs:
@@ -136,8 +136,7 @@ class IPEXLLMLLMEngine(LLMEngine):
         cls,
         engine_args: EngineArgs,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        load_in_low_bit: str = "sym_int4",
-        # ipex_llm_optimize_mode: str = 'NATIVE',
+        load_in_low_bit: Optional[str] = None,
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
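Note: all three entry points above now default load_in_low_bit to None, so low-bit conversion is strictly opt-in. A minimal sketch of an explicit opt-in through the async engine (the ipex_llm import path and the model id are assumptions for illustration; only the keyword names follow the hunks above):

# Sketch only, not part of the diff.
from vllm.engine.arg_utils import AsyncEngineArgs
from ipex_llm.vllm.cpu.engine import IPEXLLMAsyncLLMEngine  # assumed import path

engine_args = AsyncEngineArgs(model="meta-llama/Llama-2-7b-chat-hf")  # placeholder model
engine = IPEXLLMAsyncLLMEngine.from_engine_args(
    engine_args,
    load_in_low_bit="sym_int4",   # omit (or pass None) to keep the stock vLLM CPU path
)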
@@ -65,7 +65,7 @@ def parse_args():
     parser.add_argument(
         "--load-in-low-bit",
         type=str,
-        default="sym_int4",
+        default=None,
         help="Low-bit quantization for IPEX-LLM models")
     return parser.parse_args()
 
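The CLI flag mirrors the new engine default: leaving --load-in-low-bit unset now yields None instead of "sym_int4". A self-contained sketch of the resulting behaviour (the parser lines match the hunk; the sample invocations are illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--load-in-low-bit",
    type=str,
    default=None,
    help="Low-bit quantization for IPEX-LLM models")

print(parser.parse_args([]).load_in_low_bit)                                  # None -> no conversion
print(parser.parse_args(["--load-in-low-bit", "sym_int4"]).load_in_low_bit)   # explicit opt-in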
@@ -16,6 +16,7 @@
 import torch
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.model_loader.utils import get_model_architecture
 from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention
 from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
 from vllm.model_executor.models.qwen import QWenMLP, QWenAttention
@@ -24,11 +25,14 @@ from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
 from vllm.attention import Attention, AttentionMetadata
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.config import DeviceConfig
+from vllm.logger import init_logger
 
 from vllm._C import ops
 from ipex_llm.utils.common import invalidInputError
 from typing import List, Optional, Tuple, Union
 
+logger = init_logger(__name__)
+
 
 def _MLP_forward(self, x):
     gate_up = self.gate_up_proj(x)
@@ -59,10 +63,10 @@ def _QWen_Attention_forward(
     kv_cache: Tuple[torch.Tensor, torch.Tensor],
     attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
-    qkv = self.c_attn(hidden_states)
+    qkv = self.c_attn(hidden_states).to(dtype=kv_cache.dtype)
     q, k, v = qkv.chunk(chunks=3, dim=-1)
     q, k = self.rotary_emb(positions, q, k)
-    attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
+    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
     output = self.c_proj(attn_output)
     return output
 
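Two things change in each patched attention forward: the projection output is cast to the KV cache's dtype, and self.kv_scale is no longer passed to self.attn(...). A toy, self-contained illustration of why the cast matters (tensor shapes and dtypes are made up; only the .to(dtype=kv_cache.dtype) call is taken from the diff):

import torch

kv_cache = torch.zeros(2, 16, 8, dtype=torch.bfloat16)   # stand-in for the paged KV cache
qkv = torch.randn(4, 24, dtype=torch.float32)            # stand-in for self.c_attn(hidden_states)

qkv = qkv.to(dtype=kv_cache.dtype)                       # the fix: align dtypes before caching
assert qkv.dtype == kv_cache.dtype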
@@ -74,6 +78,21 @@ def _QWen_MLP_forward(self, x):
     return x
 
 
+def _Qwen2_Attention_forward(
+    self,
+    positions: torch.Tensor,
+    hidden_states: torch.Tensor,
+    kv_cache: torch.Tensor,
+    attn_metadata: AttentionMetadata,
+) -> torch.Tensor:
+    qkv = self.qkv_proj(hidden_states).to(dtype=kv_cache.dtype)
+    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+    q, k = self.rotary_emb(positions, q, k)
+    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+    output = self.o_proj(attn_output)
+    return output
+
+
 def _ChatGLM_MLP_forward(self, hidden_states):
     # [s, b, 4hp]
     intermediate_parallel = self.dense_h_to_4h(hidden_states)
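The new _Qwen2_Attention_forward mirrors the generic attention path but adds the same dtype alignment before the KV cache is written, and splits the fused projection by q_size/kv_size. A toy shape check of that split (the sizes below are invented; in the model they come from head counts and head_dim):

import torch

q_size, kv_size = 64, 16                      # e.g. 8 query heads vs. 2 KV heads of dim 8 (illustrative)
qkv = torch.randn(2, q_size + 2 * kv_size)    # stand-in for self.qkv_proj(hidden_states)

q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
assert q.shape[-1] == q_size and k.shape[-1] == kv_size and v.shape[-1] == kv_size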
@@ -90,11 +109,11 @@ def _Baichuan_Attention_forward(
     kv_cache: Tuple[torch.Tensor, torch.Tensor],
     attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
-    qkv = self.W_pack(hidden_states)
+    qkv = self.W_pack(hidden_states).to(dtype=kv_cache.dtype)
     q, k, v = qkv.chunk(chunks=3, dim=-1)
     if self.postion_embedding != "ALIBI":
         q, k = self.rotary_emb(positions, q, k)
-    attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
+    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
     output = self.o_proj(attn_output)
     return output
 
@@ -106,7 +125,7 @@ def _ChatGLM_Attention_forward(
     kv_cache: Tuple[torch.Tensor, torch.Tensor],
     attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
-    qkv = self.query_key_value(hidden_states)
+    qkv = self.query_key_value(hidden_states).to(dtype=kv_cache.dtype)
     q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
     q, k = self.rotary_emb(position_ids, q, k)
     context_layer = self.attn(
@@ -123,18 +142,25 @@ _REPLACED_MLP_LAYERS = {
     LlamaMLP: _MLP_forward,
     Qwen2MLP: _MLP_forward,
     BaiChuanMLP: _MLP_forward,
-    QWenMLP: _QWen_MLP_forward,
+    # QWenMLP: _QWen_MLP_forward,
     GLMMLP: _ChatGLM_MLP_forward
 }
 
 _REPLACED_ATTENTION_LAYERS = {
     LlamaAttention: _Attention_forward,
-    Qwen2Attention: _Attention_forward,
-    QWenAttention: _QWen_Attention_forward,
+    Qwen2Attention: _Qwen2_Attention_forward,
+    # QWenAttention: _QWen_Attention_forward,
     BaiChuanAttention: _Baichuan_Attention_forward,
     GLMAttention: _ChatGLM_Attention_forward
 }
 
+_IPEX_LLM_SUPPORTED_MODELS = [
+    "LlamaForCausalLM",
+    "BaichuanForCausalLM",
+    "ChatGLMForCausalLM",
+    "Qwen2ForCausalLM",
+]
+
 
 def _model_mlp_convert():
     for module, replaced_func in _REPLACED_MLP_LAYERS.items():
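The QWen entries are commented out of both tables, Qwen2Attention is re-pointed at the new forward, and _IPEX_LLM_SUPPORTED_MODELS records which architectures the convert path handles. A hedged sketch of how such a table is typically applied (the diff only shows the loop header of _model_mlp_convert; the setattr body here is an assumption):

def _model_mlp_convert_sketch(replaced_mlp_layers):
    # Patch each mapped vLLM layer class so every instance uses the replacement forward.
    for module, replaced_func in replaced_mlp_layers.items():
        setattr(module, "forward", replaced_func)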
@@ -147,6 +173,8 @@ def _model_attention_convert():
 
 
 def _ipex_llm_convert(load_in_low_bit):
+    if load_in_low_bit is None:
+        return
     from vllm.worker.cpu_model_runner import CPUModelRunner
     import vllm.model_executor.model_loader as model_loader
     setattr(CPUModelRunner, "load_model", get_load_function(load_in_low_bit))
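With the added guard, _ipex_llm_convert is a no-op when no low-bit format is requested, which is what lets the new None defaults fall back to stock vLLM CPU loading. A minimal behavioural sketch (function name and return values are illustrative only):

def convert_if_requested(load_in_low_bit):
    # Mirrors the guard added above: None means "leave vLLM untouched".
    if load_in_low_bit is None:
        return False
    # ...patch CPUModelRunner.load_model and the MLP/attention forwards...
    return True

assert convert_if_requested(None) is False
assert convert_if_requested("sym_int4") is True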
@@ -206,6 +234,26 @@ def _ipex_llm_rmsnorm_forward(
 
 def get_load_function(low_bit):
     def _ipex_llm_load_model(self) -> None:
+        model_class = get_model_architecture(self.model_config)[1]
+        cur_model_list = ", ".join(_IPEX_LLM_SUPPORTED_MODELS)
+        if low_bit != "bf16":
+            invalidInputError(model_class in _IPEX_LLM_SUPPORTED_MODELS,
+                              f"Currently IPEX-LLM vLLM convert only support {cur_model_list}.")
+        else:
+            if model_class not in _IPEX_LLM_SUPPORTED_MODELS:
+                logger.warning(
+                    f"Currently IPEX-LLM vLLM convert only support {cur_model_list}."
+                )
+                self.model = get_model(
+                    model_config=self.model_config,
+                    load_config=self.load_config,
+                    device_config=self.device_config,
+                    vision_language_config=self.vision_language_config,
+                    lora_config=self.lora_config,
+                    parallel_config=self.parallel_config,
+                    scheduler_config=self.scheduler_config)
+                return
+
         _model_mlp_convert()
         _model_attention_convert()
 
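The new checks in _ipex_llm_load_model gate conversion on the detected architecture: any non-bf16 low-bit format hard-fails for unsupported models, while bf16 logs a warning, loads the model with the stock get_model(...) call, and returns before any patching. A condensed sketch of that decision (names and return strings are illustrative, not part of the diff):

SUPPORTED = ["LlamaForCausalLM", "BaichuanForCausalLM",
             "ChatGLMForCausalLM", "Qwen2ForCausalLM"]

def convert_decision(model_class: str, low_bit: str) -> str:
    if model_class in SUPPORTED:
        return "convert"                       # patch forwards, then optimize_model(...)
    if low_bit != "bf16":
        raise ValueError(
            f"Currently IPEX-LLM vLLM convert only support {', '.join(SUPPORTED)}.")
    return "stock-loader"                      # warn and fall back to plain get_model(...)

assert convert_decision("Qwen2ForCausalLM", "sym_int4") == "convert"
assert convert_decision("MistralForCausalLM", "bf16") == "stock-loader"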
@@ -221,19 +269,4 @@ def get_load_function(low_bit):
         from ipex_llm import optimize_model
         optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype)
 
-        if self.lora_config:
-            invalidInputError(hasattr(self.model, "supported_lora_modules")
-                              and self.model.supported_lora_modules,
-                              "Model does not support LoRA")
-            invalidInputError(hasattr(self.model, "embedding_modules"),
-                              "Model does not have embedding_modules")
-            invalidInputError(hasattr(self.model, "embedding_padding_modules"),
-                              "Model does not have embedding_padding_modules")
-            self.lora_manager = LRUCacheWorkerLoRAManager(
-                self.scheduler_config.max_num_seqs,
-                self.scheduler_config.max_num_batched_tokens +
-                self.scheduler_config.max_paddings, self.vocab_size,
-                self.lora_config, self.device, self.model.embedding_modules,
-                self.model.embedding_padding_modules)
-            self.model = self.lora_manager.create_lora_manager(self.model)
     return _ipex_llm_load_model