Enable ipex-llm optimization for lm head (#11589)
* basic
* Modify convert.py
* fix
parent 365adad59f
commit 06930ab258
2 changed files with 83 additions and 7 deletions
@@ -149,9 +149,11 @@ def is_linear_module(module):
         from vllm.model_executor.layers.linear import (
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         )
+        from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
         VLLM_LINEAR_LIST = [
-            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
+            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
+            MergedColumnParallelLinear,
+            ParallelLMHead
         ]
         if is_module_in_classes(module, VLLM_LINEAR_LIST):
             if 'xpu' in _VLLM_VERSION:
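For reference, a minimal standalone sketch of the class-list dispatch this hunk extends. The body of is_module_in_classes is an assumption here (the diff only shows its call site), and nn.Linear/nn.Embedding stand in for the vLLM layer classes:

import torch.nn as nn

def is_module_in_classes(module, classes):
    # True if the module is an instance of any class in the list
    return any(isinstance(module, cls) for cls in classes)

# stand-ins for VLLM_LINEAR_LIST; the real list now also holds ParallelLMHead
CONVERTIBLE = [nn.Linear, nn.Embedding]

print(is_module_in_classes(nn.Linear(8, 8), CONVERTIBLE))   # True
print(is_module_in_classes(nn.LayerNorm(8), CONVERTIBLE))   # False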
@@ -167,6 +169,12 @@ def is_linear_module(module):
             else:
                 # For vllm cpu
                 tp_size = 1
+            if isinstance(module, ParallelLMHead) and 'xpu' in _VLLM_VERSION:
+                in_features = module.embedding_dim
+                out_features = module.num_embeddings_per_partition
+                result = True
+                mp_group = None
+                return result, (in_features, out_features, mp_group)
             in_features = module.input_size
             out_features = module.output_size
             result = True
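The new branch is needed because ParallelLMHead does not expose input_size/output_size; its shape lives in embedding_dim and num_embeddings_per_partition. A standalone sketch with hypothetical mock classes (not the vLLM ones) showing the same attribute-based dispatch:

class MockRowParallelLinear:
    def __init__(self, input_size, output_size):
        self.input_size = input_size        # regular parallel linears expose these
        self.output_size = output_size

class MockParallelLMHead:
    def __init__(self, num_embeddings, embedding_dim):
        self.embedding_dim = embedding_dim                   # hidden size
        self.num_embeddings_per_partition = num_embeddings   # vocab shard size

def linear_shape(module):
    # check the lm head first: it has no input_size/output_size at all
    if isinstance(module, MockParallelLMHead):
        return module.embedding_dim, module.num_embeddings_per_partition
    return module.input_size, module.output_size

print(linear_shape(MockRowParallelLinear(4096, 11008)))  # (4096, 11008)
print(linear_shape(MockParallelLMHead(32000, 4096)))     # (4096, 32000)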
@@ -15,18 +15,71 @@
 #
 import torch
 from vllm.logger import init_logger
-from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention
+from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention, LlamaForCausalLM
-from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
+from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention, Qwen2ForCausalLM
-from vllm.model_executor.models.qwen import QWenMLP, QWenAttention
+from vllm.model_executor.models.qwen import QWenMLP, QWenAttention, QWenLMHeadModel
 from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention
-from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
+from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
+from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention, ChatGLMForCausalLM
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.layers.sampler import Sampler
+
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.config import DeviceConfig
-from typing import Tuple
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather
+
+from typing import Tuple, Optional
 from ipex_llm.utils.common import invalidInputError
+from vllm.sequence import SamplerOutput
+
+
+def _Llama_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    next_tokens = self.sampler(self.lm_head, hidden_states,
+                               sampling_metadata)
+    return next_tokens
+
+
+def _Qwen2_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    if self.config.tie_word_embeddings:
+        lm_head_weight = self.model.embed_tokens
+    else:
+        lm_head_weight = self.lm_head
+    next_tokens = self.sampler(lm_head_weight, hidden_states,
+                               sampling_metadata)
+    return next_tokens
+
+
+def _Chatglm_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    next_tokens = self.sampler(self.transformer.output_layer, hidden_states,
+                               sampling_metadata)
+    return next_tokens
+
+
+def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
+                       embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+    logits = embedding(hidden_states)
+    if embedding_bias is not None:
+        logits += embedding_bias
+    logits = tensor_model_parallel_gather(logits)
+    # Remove paddings in vocab (if any).
+    if logits is not None:
+        logits = logits[:, :self.org_vocab_size]
+    return logits
 
 
 def _MLP_forward(self, x):
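The point of _sample_get_logits is that it computes logits by calling the lm head as a module (embedding(hidden_states)) instead of multiplying by its raw weight, so once the head has been converted to an ipex-llm low-bit Linear the optimized kernel is used. A standalone sketch of the same computation, with a plain Linear standing in for the converted head and the tensor-parallel gather elided (single-device assumption):

import torch

def get_logits(hidden_states, embedding, embedding_bias, org_vocab_size):
    logits = embedding(hidden_states)        # module call -> optimized kernel applies
    if embedding_bias is not None:
        logits += embedding_bias
    # tensor_model_parallel_gather would combine shards here; no-op on one device
    if logits is not None:
        logits = logits[:, :org_vocab_size]  # drop vocab padding columns
    return logits

lm_head = torch.nn.Linear(16, 128, bias=False)   # 128 = padded vocab size
hidden = torch.randn(4, 16)
print(get_logits(hidden, lm_head, None, org_vocab_size=100).shape)  # torch.Size([4, 100])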
@@ -139,12 +192,26 @@ _REPLACED_ATTENTION_LAYERS = {
     GLMAttention: _ChatGLM_Attention_forward
 }
 
+_REPLACED_SAMPLER_LAYERS = {
+    LlamaForCausalLM: _Llama_sample,
+    QWenLMHeadModel: _Llama_sample,
+    ChatGLMForCausalLM: _Chatglm_sample,
+    Qwen2ForCausalLM: _Qwen2_sample,
+    BaiChuanBaseForCausalLM: _Llama_sample,
+}
+
 
 def _model_mlp_convert():
     for module, replaced_func in _REPLACED_MLP_LAYERS.items():
         setattr(module, "forward", replaced_func)
 
 
+def _model_sample_convert():
+    setattr(Sampler, "_get_logits", _sample_get_logits)
+    for module, replaced_func in _REPLACED_SAMPLER_LAYERS.items():
+        setattr(module, "sample", replaced_func)
+
+
 def _model_attention_convert():
     for module, replaced_func in _REPLACED_ATTENTION_LAYERS.items():
         setattr(module, "forward", replaced_func)
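A standalone sketch of the registry-plus-setattr pattern _model_sample_convert follows; the two classes are hypothetical stand-ins, and it also shows why one replacement (_Llama_sample) can serve several model classes:

class ModelA:
    def sample(self):
        return "original A"

class ModelB:
    def sample(self):
        return "original B"

def _shared_sample(self):
    # a module-level function taking `self` becomes a normal method once bound
    return f"patched {type(self).__name__}"

# one function registered for several classes, like _Llama_sample for
# LlamaForCausalLM, QWenLMHeadModel and BaiChuanBaseForCausalLM above
_REPLACED = {ModelA: _shared_sample, ModelB: _shared_sample}

for cls, fn in _REPLACED.items():
    setattr(cls, "sample", fn)

print(ModelA().sample())  # patched ModelA
print(ModelB().sample())  # patched ModelB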
@@ -160,6 +227,7 @@ def get_load_function(low_bit):
     def _ipex_llm_load_model(self) -> None:
         _model_mlp_convert()
         _model_attention_convert()
+        _model_sample_convert()
 
         from vllm.utils import measure_device_memory
         with measure_device_memory() as m:
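Calling _model_sample_convert() inside _ipex_llm_load_model, before the model is built, is enough because setattr on a class rebinds the method for instances created afterwards. A minimal standalone illustration (Sampler here is a toy class, not vLLM's):

class Sampler:
    def _get_logits(self):
        return "stock logits path"

def _sample_get_logits(self):
    return "ipex-llm logits path"

setattr(Sampler, "_get_logits", _sample_get_logits)  # patch applied first

sampler = Sampler()           # constructed after the patch, as during model load
print(sampler._get_logits())  # ipex-llm logits path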