#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from vllm.logger import init_logger
from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention, LlamaForCausalLM
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention, Qwen2ForCausalLM
from vllm.model_executor.models.qwen import QWenMLP, QWenAttention, QWenLMHeadModel
from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention, ChatGLMForCausalLM
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.sampler import Sampler
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.input_metadata import InputMetadata
from vllm.config import DeviceConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather
from typing import Tuple, Optional
from ipex_llm.utils.common import invalidInputError
from vllm.sequence import SamplerOutput


# Replacement `sample` methods for the supported vLLM model classes. They pass
# the lm_head (or tied word embedding) module directly to the sampler; the
# patched Sampler._get_logits below then computes the logits by calling that
# module.
def _Llama_sample(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    next_tokens = self.sampler(self.lm_head, hidden_states, sampling_metadata)
    return next_tokens


def _Qwen2_sample(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    if self.config.tie_word_embeddings:
        lm_head_weight = self.model.embed_tokens
    else:
        lm_head_weight = self.lm_head
    next_tokens = self.sampler(lm_head_weight, hidden_states, sampling_metadata)
    return next_tokens


def _Chatglm_sample(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    next_tokens = self.sampler(self.transformer.output_layer, hidden_states,
                               sampling_metadata)
    return next_tokens


# Patched Sampler._get_logits: compute the logits by invoking the head module
# instead of multiplying by its weight matrix.
def _sample_get_logits(self, hidden_states: torch.Tensor,
                       embedding: torch.nn.Module,
                       embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
    logits = embedding(hidden_states)
    if embedding_bias is not None:
        logits += embedding_bias
    logits = tensor_model_parallel_gather(logits)
    # Remove paddings in vocab (if any).
    if logits is not None:
        logits = logits[:, :self.org_vocab_size]
    return logits


# Replacement forward passes for the MLP and attention layers of the supported
# models. They are patched in by the optional _model_mlp_convert() and
# _model_attention_convert() helpers below and expect the projection layers to
# return a plain tensor rather than vLLM's (output, bias) tuple.
def _MLP_forward(self, x):
    gate_up = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x = self.down_proj(x)
    return x


def _Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
    qkv = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    k_cache, v_cache = kv_cache
    attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
    output = self.o_proj(attn_output)
    return output


def _QWen_Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    input_metadata: InputMetadata,
) -> torch.Tensor:
    qkv = self.c_attn(hidden_states)
    q, k, v = qkv.chunk(chunks=3, dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    k_cache, v_cache = kv_cache
    attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
    output = self.c_proj(attn_output)
    return output


def _QWen_MLP_forward(self, x):
    gate_up = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x = self.c_proj(x)
    return x


def _ChatGLM_MLP_forward(self, hidden_states):
    # [s, b, 4hp]
    intermediate_parallel = self.dense_h_to_4h(hidden_states)
    intermediate_parallel = self.activation_func(intermediate_parallel)
    # [s, b, h]
    output = self.dense_4h_to_h(intermediate_parallel)
    return output


def _Baichuan_Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    input_metadata: InputMetadata,
) -> torch.Tensor:
    qkv = self.W_pack(hidden_states)
    q, k, v = qkv.chunk(chunks=3, dim=-1)
    # `postion_embedding` (sic) matches the attribute name defined by vLLM's
    # BaiChuanAttention.
    if self.postion_embedding != "ALIBI":
        q, k = self.rotary_emb(positions, q, k)
    k_cache, v_cache = kv_cache
    attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
    output = self.o_proj(attn_output)
    return output


def _ChatGLM_Attention_forward(
    self,
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor,
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    input_metadata: InputMetadata,
) -> torch.Tensor:
    qkv = self.query_key_value(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    q, k = self.rotary_emb(position_ids, q, k)
    key_cache, value_cache = kv_cache
    context_layer = self.attn(
        q,
        k,
        v,
        key_cache,
        value_cache,
        input_metadata,
    )
    attn_output = self.dense(context_layer)
    return attn_output


# Maps from vLLM layer/model classes to the replacement methods defined above;
# consumed by the _model_*_convert() helpers below.
_REPLACED_MLP_LAYERS = {
    LlamaMLP: _MLP_forward,
    Qwen2MLP: _MLP_forward,
    BaiChuanMLP: _MLP_forward,
    QWenMLP: _QWen_MLP_forward,
    GLMMLP: _ChatGLM_MLP_forward,
}

_REPLACED_ATTENTION_LAYERS = {
    LlamaAttention: _Attention_forward,
    Qwen2Attention: _Attention_forward,
    QWenAttention: _QWen_Attention_forward,
    BaiChuanAttention: _Baichuan_Attention_forward,
    GLMAttention: _ChatGLM_Attention_forward,
}

_REPLACED_SAMPLER_LAYERS = {
    LlamaForCausalLM: _Llama_sample,
    QWenLMHeadModel: _Llama_sample,
    ChatGLMForCausalLM: _Chatglm_sample,
    Qwen2ForCausalLM: _Qwen2_sample,
    BaiChuanBaseForCausalLM: _Llama_sample,
}


def _model_mlp_convert():
    for module, replaced_func in _REPLACED_MLP_LAYERS.items():
        setattr(module, "forward", replaced_func)


def _model_sample_convert():
    setattr(Sampler, "_get_logits", _sample_get_logits)
    for module, replaced_func in _REPLACED_SAMPLER_LAYERS.items():
        setattr(module, "sample", replaced_func)


def _model_attention_convert():
    for module, replaced_func in _REPLACED_ATTENTION_LAYERS.items():
        setattr(module, "forward", replaced_func)


def _ipex_llm_convert(load_in_low_bit):
    from vllm.worker.model_runner import ModelRunner
    import vllm.model_executor.model_loader as model_loader
    setattr(ModelRunner, "load_model", get_load_function(load_in_low_bit))


def get_load_function(low_bit):
    def _ipex_llm_load_model(self) -> None:
        # _model_mlp_convert()
        # _model_attention_convert()
        _model_sample_convert()

        from vllm.utils import measure_device_memory
        with measure_device_memory() as m:
            # Only XPU is supported for now. Load onto a fresh CPU
            # DeviceConfig; otherwise the XPU memory usage would be
            # reported incorrectly.
            self.model = get_model(self.model_config,
                                   DeviceConfig("cpu"),
                                   lora_config=self.lora_config,
                                   parallel_config=self.parallel_config,
                                   scheduler_config=self.scheduler_config)
            from ipex_llm import optimize_model
            optimize_model(self.model, low_bit=low_bit,
                           torch_dtype=self.model_config.dtype)
            self.model = self.model.to(device=self.device_config.device,
                                       dtype=self.model_config.dtype)

        self.model_memory_usage = m.consumed_memory
        logger = init_logger(__name__)
        logger.info(f"Loading model weights took "
                    f"{self.model_memory_usage / float(2**30):.4f} GB")

        if self.lora_config:
            invalidInputError(hasattr(self.model, "supported_lora_modules")
                              and self.model.supported_lora_modules,
                              "Model does not support LoRA")
            invalidInputError(hasattr(self.model, "embedding_modules"),
                              "Model does not have embedding_modules")
            invalidInputError(hasattr(self.model, "embedding_padding_modules"),
                              "Model does not have embedding_padding_modules")
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens +
                self.scheduler_config.max_paddings,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)

    return _ipex_llm_load_model
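

# ---------------------------------------------------------------------------
# Illustrative usage sketch only; not part of this module's public surface.
# The import path and the "sym_int4" value below are assumptions chosen for
# the example. What the code above does imply is that _ipex_llm_convert()
# must run before vLLM's ModelRunner loads the model, so that the patched
# load_model takes effect:
#
#     from ipex_llm.vllm.model_convert import _ipex_llm_convert
#
#     _ipex_llm_convert(load_in_low_bit="sym_int4")
#     # ... then construct the vLLM engine as usual; ModelRunner.load_model
#     # now loads weights on CPU, applies ipex-llm low-bit optimization, and
#     # moves the model to the configured device.
# ---------------------------------------------------------------------------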