LLM: Fix vLLM CPU model convert mismatch (#11254)
Fix vLLM CPU model convert mismatch.
This commit is contained in:
parent 42fab480ea
commit 4b07712fd8

3 changed files with 60 additions and 28 deletions

@@ -37,7 +37,7 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
         engine_args: AsyncEngineArgs,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        load_in_low_bit: str = "sym_int4",
+        load_in_low_bit: Optional[str] = None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Enable ipex-llm optimizations

@@ -97,7 +97,7 @@ class IPEXLLMClass(LLM):
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
-        load_in_low_bit: str = "sym_int4",
+        load_in_low_bit: Optional[str] = None,
         **kwargs,
     ) -> None:
         if "disable_log_stats" not in kwargs:

@@ -136,8 +136,7 @@ class IPEXLLMLLMEngine(LLMEngine):
         cls,
         engine_args: EngineArgs,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        load_in_low_bit: str = "sym_int4",
-        # ipex_llm_optimize_mode: str = 'NATIVE',
+        load_in_low_bit: Optional[str] = None,
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
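
Note: the three hunks above make low-bit conversion opt-in for the engine wrappers (IPEXLLMAsyncLLMEngine, IPEXLLMClass, IPEXLLMLLMEngine) by changing the load_in_low_bit default from "sym_int4" to None. A minimal usage sketch of the new default follows; the import path and model id are illustrative assumptions, not taken from this diff:

# Hypothetical usage; import path and model id are illustrative assumptions.
from ipex_llm.vllm.cpu.engine import IPEXLLMClass

# New default: load_in_low_bit=None, so no low-bit convert is applied.
llm = IPEXLLMClass(model="meta-llama/Llama-2-7b-chat-hf")

# Opt in explicitly to keep the previous behavior.
llm_int4 = IPEXLLMClass(model="meta-llama/Llama-2-7b-chat-hf", load_in_low_bit="sym_int4")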

@@ -65,7 +65,7 @@ def parse_args():
     parser.add_argument(
         "--load-in-low-bit",
         type=str,
-        default="sym_int4",
+        default=None,
         help="Low-bit quantization for IPEX-LLM models")
     return parser.parse_args()
 
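
With the default now None, launching the example server without --load-in-low-bit no longer forces "sym_int4". A self-contained sketch of just this flag's behavior (only the argument shown above; the rest of the parser is omitted):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--load-in-low-bit",
    type=str,
    default=None,
    help="Low-bit quantization for IPEX-LLM models")

print(parser.parse_args([]).load_in_low_bit)                                  # None -> convert path is skipped later
print(parser.parse_args(["--load-in-low-bit", "sym_int4"]).load_in_low_bit)   # "sym_int4"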

@@ -16,6 +16,7 @@
 import torch
+from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.utils import get_model_architecture
 from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention
 from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
 from vllm.model_executor.models.qwen import QWenMLP, QWenAttention

@@ -24,11 +25,14 @@ from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
 from vllm.attention import Attention, AttentionMetadata
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.config import DeviceConfig
-from vllm.logger import init_logger
 
 from vllm._C import ops
 from ipex_llm.utils.common import invalidInputError
 from typing import List, Optional, Tuple, Union
 
+logger = init_logger(__name__)
+
+
 def _MLP_forward(self, x):
     gate_up = self.gate_up_proj(x)

@@ -59,10 +63,10 @@ def _QWen_Attention_forward(
     kv_cache: Tuple[torch.Tensor, torch.Tensor],
     attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
-    qkv = self.c_attn(hidden_states)
+    qkv = self.c_attn(hidden_states).to(dtype=kv_cache.dtype)
     q, k, v = qkv.chunk(chunks=3, dim=-1)
     q, k = self.rotary_emb(positions, q, k)
-    attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
+    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
     output = self.c_proj(attn_output)
     return output
 

@@ -74,6 +78,21 @@ def _QWen_MLP_forward(self, x):
     return x
 
 
+def _Qwen2_Attention_forward(
+    self,
+    positions: torch.Tensor,
+    hidden_states: torch.Tensor,
+    kv_cache: torch.Tensor,
+    attn_metadata: AttentionMetadata,
+) -> torch.Tensor:
+    qkv = self.qkv_proj(hidden_states).to(dtype=kv_cache.dtype)
+    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+    q, k = self.rotary_emb(positions, q, k)
+    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+    output = self.o_proj(attn_output)
+    return output
+
+
 def _ChatGLM_MLP_forward(self, hidden_states):
     # [s, b, 4hp]
     intermediate_parallel = self.dense_h_to_4h(hidden_states)
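
The new _Qwen2_Attention_forward (and the other forwards in this file) now casts the fused QKV projection output to the KV-cache dtype before calling attention. A tiny standalone illustration of the mismatch this avoids (dtypes are illustrative; in the real path the projection may run in a different dtype than the CPU KV cache after low-bit optimization):

import torch

kv_cache = torch.zeros(2, 4, dtype=torch.bfloat16)    # stand-in for the CPU KV cache
qkv = torch.randn(2, 4, dtype=torch.float32)          # stand-in for the projection output
qkv = qkv.to(dtype=kv_cache.dtype)                     # the cast added in this commit
assert qkv.dtype == kv_cache.dtype                     # attention now sees matching dtypes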

@@ -90,11 +109,11 @@ def _Baichuan_Attention_forward(
     kv_cache: Tuple[torch.Tensor, torch.Tensor],
     attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
-    qkv = self.W_pack(hidden_states)
+    qkv = self.W_pack(hidden_states).to(dtype=kv_cache.dtype)
     q, k, v = qkv.chunk(chunks=3, dim=-1)
     if self.postion_embedding != "ALIBI":
         q, k = self.rotary_emb(positions, q, k)
-    attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
+    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
     output = self.o_proj(attn_output)
     return output
 

@@ -106,7 +125,7 @@ def _ChatGLM_Attention_forward(
     kv_cache: Tuple[torch.Tensor, torch.Tensor],
     attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
-    qkv = self.query_key_value(hidden_states)
+    qkv = self.query_key_value(hidden_states).to(dtype=kv_cache.dtype)
     q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
     q, k = self.rotary_emb(position_ids, q, k)
     context_layer = self.attn(

@@ -123,18 +142,25 @@ _REPLACED_MLP_LAYERS = {
     LlamaMLP: _MLP_forward,
     Qwen2MLP: _MLP_forward,
     BaiChuanMLP: _MLP_forward,
-    QWenMLP: _QWen_MLP_forward,
+    # QWenMLP: _QWen_MLP_forward,
     GLMMLP: _ChatGLM_MLP_forward
 }
 
 _REPLACED_ATTENTION_LAYERS = {
     LlamaAttention: _Attention_forward,
-    Qwen2Attention: _Attention_forward,
-    QWenAttention: _QWen_Attention_forward,
+    Qwen2Attention: _Qwen2_Attention_forward,
+    # QWenAttention: _QWen_Attention_forward,
     BaiChuanAttention: _Baichuan_Attention_forward,
     GLMAttention: _ChatGLM_Attention_forward
 }
 
+_IPEX_LLM_SUPPORTED_MODELS = [
+    "LlamaForCausalLM",
+    "BaichuanForCausalLM",
+    "ChatGLMForCausalLM",
+    "Qwen2ForCausalLM",
+]
+
 
 def _model_mlp_convert():
     for module, replaced_func in _REPLACED_MLP_LAYERS.items():
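
The dictionaries above drive the convert step: for each supported vLLM layer class, its forward is swapped for the IPEX-LLM variant. A simplified, self-contained sketch of that pattern (the real loop body lives outside this diff, so the setattr shown here is an assumption):

class _FakeMLP:                        # stand-in for e.g. LlamaMLP
    def forward(self, x):
        return x

def _ipex_mlp_forward(self, x):        # stand-in for _MLP_forward
    return x * 2

_replaced_mlp_layers = {_FakeMLP: _ipex_mlp_forward}

for module, replaced_func in _replaced_mlp_layers.items():
    setattr(module, "forward", replaced_func)    # assumed to mirror _model_mlp_convert

assert _FakeMLP().forward(3) == 6      # the patched forward is now in effect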

@@ -147,6 +173,8 @@ def _model_attention_convert():
 
 
 def _ipex_llm_convert(load_in_low_bit):
+    if load_in_low_bit is None:
+        return
     from vllm.worker.cpu_model_runner import CPUModelRunner
     import vllm.model_executor.model_loader as model_loader
     setattr(CPUModelRunner, "load_model", get_load_function(load_in_low_bit))
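
The early return above is what makes the new None default safe end to end: when no low-bit mode is requested, CPUModelRunner.load_model is left untouched. A simplified analogue of that control flow (the string returns are only for illustration; the real function patches vllm.worker.cpu_model_runner.CPUModelRunner as shown):

def _ipex_llm_convert_sketch(load_in_low_bit):
    # Mirrors the guard added above: None (the new default) skips conversion entirely.
    if load_in_low_bit is None:
        return "unpatched: stock vLLM CPU load_model"
    return "patched: load_model wrapped for low_bit=" + load_in_low_bit

print(_ipex_llm_convert_sketch(None))           # default after this commit
print(_ipex_llm_convert_sketch("sym_int4"))     # previous default, now opt-in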

@@ -206,6 +234,26 @@ def _ipex_llm_rmsnorm_forward(
 
 def get_load_function(low_bit):
     def _ipex_llm_load_model(self) -> None:
+        model_class = get_model_architecture(self.model_config)[1]
+        cur_model_list = ", ".join(_IPEX_LLM_SUPPORTED_MODELS)
+        if low_bit != "bf16":
+            invalidInputError(model_class in _IPEX_LLM_SUPPORTED_MODELS,
+                              f"Currently IPEX-LLM vLLM convert only support {cur_model_list}.")
+        else:
+            if model_class not in _IPEX_LLM_SUPPORTED_MODELS:
+                logger.warning(
+                    f"Currently IPEX-LLM vLLM convert only support {cur_model_list}."
+                )
+                self.model = get_model(
+                    model_config=self.model_config,
+                    load_config=self.load_config,
+                    device_config=self.device_config,
+                    vision_language_config=self.vision_language_config,
+                    lora_config=self.lora_config,
+                    parallel_config=self.parallel_config,
+                    scheduler_config=self.scheduler_config)
+                return
+
         _model_mlp_convert()
         _model_attention_convert()
 
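
The block added above gates the convert by architecture: for real low-bit modes an unsupported model class is a hard error via invalidInputError, while for "bf16" the loader only warns and falls back to vLLM's stock get_model(). A self-contained restatement of that branching (the unsupported class name in the example is illustrative):

_IPEX_LLM_SUPPORTED_MODELS = [
    "LlamaForCausalLM",
    "BaichuanForCausalLM",
    "ChatGLMForCausalLM",
    "Qwen2ForCausalLM",
]

def convert_action(model_class: str, low_bit: str) -> str:
    supported = model_class in _IPEX_LLM_SUPPORTED_MODELS
    if low_bit != "bf16":
        if not supported:
            raise ValueError("Currently IPEX-LLM vLLM convert only support "
                             + ", ".join(_IPEX_LLM_SUPPORTED_MODELS) + ".")
        return "convert + optimize_model"
    if not supported:
        return "warn, then fall back to stock get_model()"
    return "convert + optimize_model"

print(convert_action("Qwen2ForCausalLM", "sym_int4"))   # convert + optimize_model
print(convert_action("MistralForCausalLM", "bf16"))     # warn, then fall back to stock get_model()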

@@ -221,19 +269,4 @@ def get_load_function(low_bit):
         from ipex_llm import optimize_model
         optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype)
 
-        if self.lora_config:
-            invalidInputError(hasattr(self.model, "supported_lora_modules")
-                              and self.model.supported_lora_modules,
-                              "Model does not support LoRA")
-            invalidInputError(hasattr(self.model, "embedding_modules"),
-                              "Model does not have embedding_modules")
-            invalidInputError(hasattr(self.model, "embedding_padding_modules"),
-                              "Model does not have embedding_padding_modules")
-            self.lora_manager = LRUCacheWorkerLoRAManager(
-                self.scheduler_config.max_num_seqs,
-                self.scheduler_config.max_num_batched_tokens +
-                self.scheduler_config.max_paddings, self.vocab_size,
-                self.lora_config, self.device, self.model.embedding_modules,
-                self.model.embedding_padding_modules)
-            self.model = self.lora_manager.create_lora_manager(self.model)
     return _ipex_llm_load_model