#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch

from vllm.logger import init_logger
from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention, LlamaForCausalLM
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention, Qwen2ForCausalLM
from vllm.model_executor.models.qwen import QWenMLP, QWenAttention, QWenLMHeadModel
from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention, ChatGLMForCausalLM
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.sampler import Sampler

from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.input_metadata import InputMetadata
from vllm.config import DeviceConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather

from typing import Tuple, Optional
from ipex_llm.utils.common import invalidInputError
from vllm.sequence import SamplerOutput


def _Llama_sample(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    next_tokens = self.sampler(self.lm_head, hidden_states,
                               sampling_metadata)
    return next_tokens


def _Qwen2_sample(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    if self.config.tie_word_embeddings:
        lm_head_weight = self.model.embed_tokens
    else:
        lm_head_weight = self.lm_head
    next_tokens = self.sampler(lm_head_weight, hidden_states,
                               sampling_metadata)
    return next_tokens


def _Chatglm_sample(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    next_tokens = self.sampler(self.transformer.output_layer, hidden_states,
                               sampling_metadata)
    return next_tokens


def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
                       embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
    # Replacement for Sampler._get_logits: the sample functions above pass the head in
    # as a module (e.g. self.lm_head), so it is invoked directly here and an
    # IPEX-LLM-converted low-bit projection can be used for the logits computation.
    logits = embedding(hidden_states)
    if embedding_bias is not None:
        logits += embedding_bias
    logits = tensor_model_parallel_gather(logits)
    # Remove paddings in vocab (if any).
    if logits is not None:
        logits = logits[:, :self.org_vocab_size]
    return logits
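

# The forward functions below mirror the structure of the corresponding vLLM layers but
# call the projection modules directly and use their return value as a plain tensor.
# They are only bound onto the vLLM classes when `_model_mlp_convert()` /
# `_model_attention_convert()` are invoked (both are currently disabled in
# `_ipex_llm_load_model` below).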


def _MLP_forward(self, x):
    gate_up = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x = self.down_proj(x)
    return x


def _Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
    qkv = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    k_cache, v_cache = kv_cache
    attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
    output = self.o_proj(attn_output)
    return output


def _QWen_Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    input_metadata: InputMetadata,
) -> torch.Tensor:
    qkv = self.c_attn(hidden_states)
    q, k, v = qkv.chunk(chunks=3, dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    k_cache, v_cache = kv_cache
    attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
    output = self.c_proj(attn_output)
    return output


def _QWen_MLP_forward(self, x):
    gate_up = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x = self.c_proj(x)
    return x


def _ChatGLM_MLP_forward(self, hidden_states):
    # [s, b, 4hp]
    intermediate_parallel = self.dense_h_to_4h(hidden_states)
    intermediate_parallel = self.activation_func(intermediate_parallel)
    # [s, b, h]
    output = self.dense_4h_to_h(intermediate_parallel)
    return output


def _Baichuan_Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    input_metadata: InputMetadata,
) -> torch.Tensor:
    qkv = self.W_pack(hidden_states)
    q, k, v = qkv.chunk(chunks=3, dim=-1)
    # "postion_embedding" (sic) matches the attribute name used by vLLM's
    # BaiChuanAttention; rotary embedding is skipped for ALIBI models.
    if self.postion_embedding != "ALIBI":
        q, k = self.rotary_emb(positions, q, k)
    k_cache, v_cache = kv_cache
    attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
    output = self.o_proj(attn_output)
    return output


def _ChatGLM_Attention_forward(
    self,
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor,
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    input_metadata: InputMetadata,
) -> torch.Tensor:
    qkv = self.query_key_value(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    q, k = self.rotary_emb(position_ids, q, k)
    key_cache, value_cache = kv_cache
    context_layer = self.attn(
        q,
        k,
        v,
        key_cache,
        value_cache,
        input_metadata,
    )
    attn_output = self.dense(context_layer)
    return attn_output


_REPLACED_MLP_LAYERS = {
    LlamaMLP: _MLP_forward,
    Qwen2MLP: _MLP_forward,
    BaiChuanMLP: _MLP_forward,
    QWenMLP: _QWen_MLP_forward,
    GLMMLP: _ChatGLM_MLP_forward,
}


_REPLACED_ATTENTION_LAYERS = {
    LlamaAttention: _Attention_forward,
    Qwen2Attention: _Attention_forward,
    QWenAttention: _QWen_Attention_forward,
    BaiChuanAttention: _Baichuan_Attention_forward,
    GLMAttention: _ChatGLM_Attention_forward,
}


_REPLACED_SAMPLER_LAYERS = {
    LlamaForCausalLM: _Llama_sample,
    QWenLMHeadModel: _Llama_sample,
    ChatGLMForCausalLM: _Chatglm_sample,
    Qwen2ForCausalLM: _Qwen2_sample,
    BaiChuanBaseForCausalLM: _Llama_sample,
}
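

# The dispatch tables above map upstream vLLM classes to the replacement
# implementations defined in this file. The helpers below apply them with a plain
# class-level `setattr`, so each patch affects every instance of the class, whether
# it is created before or after the call.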


def _model_mlp_convert():
    for module, replaced_func in _REPLACED_MLP_LAYERS.items():
        setattr(module, "forward", replaced_func)


def _model_sample_convert():
    setattr(Sampler, "_get_logits", _sample_get_logits)
    for module, replaced_func in _REPLACED_SAMPLER_LAYERS.items():
        setattr(module, "sample", replaced_func)


def _model_attention_convert():
    for module, replaced_func in _REPLACED_ATTENTION_LAYERS.items():
        setattr(module, "forward", replaced_func)
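

# Illustrative usage (not a public API, shown only to make the patching explicit):
#
#     _model_sample_convert()
#     assert Sampler._get_logits is _sample_get_logits
#     assert LlamaForCausalLM.sample is _Llama_sample
#
# After the call, any LlamaForCausalLM built by vLLM samples through `_Llama_sample`,
# which in turn routes logits computation through `_sample_get_logits`.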


def _ipex_llm_convert(load_in_low_bit):
    from vllm.worker.model_runner import ModelRunner
    import vllm.model_executor.model_loader as model_loader
    setattr(ModelRunner, "load_model", get_load_function(load_in_low_bit))
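

# Illustrative call site (hypothetical; in practice this is driven from IPEX-LLM's
# vLLM engine wrappers): patch ModelRunner before the vLLM engine is constructed so
# that model loading goes through the low-bit path, e.g.
#
#     _ipex_llm_convert(load_in_low_bit="sym_int4")
#     # ... then build the vLLM engine as usual; ModelRunner.load_model now points at
#     # the closure returned by get_load_function("sym_int4").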


def get_load_function(low_bit):
    def _ipex_llm_load_model(self) -> None:
        # _model_mlp_convert()
        # _model_attention_convert()
        _model_sample_convert()

        from vllm.utils import measure_device_memory
        with measure_device_memory() as m:
            # Only XPU is supported for now.
            # We have to create a new DeviceConfig; otherwise we would get the wrong
            # XPU memory usage.
            self.model = get_model(self.model_config,
                                   DeviceConfig("cpu"),
                                   lora_config=self.lora_config,
                                   parallel_config=self.parallel_config,
                                   scheduler_config=self.scheduler_config)
            from ipex_llm import optimize_model
            optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype)
            self.model = self.model.to(device=self.device_config.device,
                                       dtype=self.model_config.dtype)

        self.model_memory_usage = m.consumed_memory
        logger = init_logger(__name__)
        logger.info(f"Loading model weights took "
                    f"{self.model_memory_usage / float(2**30):.4f} GB")

        if self.lora_config:
            invalidInputError(hasattr(self.model, "supported_lora_modules")
                              and self.model.supported_lora_modules,
                              "Model does not support LoRA")
            invalidInputError(hasattr(self.model, "embedding_modules"),
                              "Model does not have embedding_modules")
            invalidInputError(hasattr(self.model, "embedding_padding_modules"),
                              "Model does not have embedding_padding_modules")
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens +
                self.scheduler_config.max_paddings, self.vocab_size,
                self.lora_config, self.device, self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)

    return _ipex_llm_load_model