Enable ipex-llm optimization for lm head (#11589)

* basic

* Modify convert.py

* fix
This commit is contained in:
Guancheng Fu 2024-07-16 16:48:44 +08:00 committed by GitHub
parent 365adad59f
commit 06930ab258
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 83 additions and 7 deletions

View file

@ -149,9 +149,11 @@ def is_linear_module(module):
from vllm.model_executor.layers.linear import (
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
VLLM_LINEAR_LIST = [
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
MergedColumnParallelLinear,
ParallelLMHead
]
if is_module_in_classes(module, VLLM_LINEAR_LIST):
if 'xpu' in _VLLM_VERSION:
@ -167,6 +169,12 @@ def is_linear_module(module):
else:
# For vllm cpu
tp_size = 1
if isinstance(module, ParallelLMHead) and 'xpu' in _VLLM_VERSION:
in_features = module.embedding_dim
out_features = module.num_embeddings_per_partition
result = True
mp_group = None
return result, (in_features, out_features, mp_group)
in_features = module.input_size
out_features = module.output_size
result = True

View file

@ -15,18 +15,71 @@
#
import torch
from vllm.logger import init_logger
from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
from vllm.model_executor.models.qwen import QWenMLP, QWenAttention
from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention, LlamaForCausalLM
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention, Qwen2ForCausalLM
from vllm.model_executor.models.qwen import QWenMLP, QWenAttention, QWenLMHeadModel
from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention
from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention, ChatGLMForCausalLM
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.sampler import Sampler
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.input_metadata import InputMetadata
from vllm.config import DeviceConfig
from typing import Tuple
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather
from typing import Tuple, Optional
from ipex_llm.utils.common import invalidInputError
from vllm.sequence import SamplerOutput
def _Llama_sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head, hidden_states,
sampling_metadata)
return next_tokens
def _Qwen2_sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens
else:
lm_head_weight = self.lm_head
next_tokens = self.sampler(lm_head_weight, hidden_states,
sampling_metadata)
return next_tokens
def _Chatglm_sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.transformer.output_layer, hidden_states,
sampling_metadata)
return next_tokens
def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
logits = embedding(hidden_states)
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
return logits
def _MLP_forward(self, x):
@ -139,12 +192,26 @@ _REPLACED_ATTENTION_LAYERS = {
GLMAttention: _ChatGLM_Attention_forward
}
_REPLACED_SAMPLER_LAYERS = {
LlamaForCausalLM: _Llama_sample,
QWenLMHeadModel: _Llama_sample,
ChatGLMForCausalLM: _Chatglm_sample,
Qwen2ForCausalLM: _Qwen2_sample,
BaiChuanBaseForCausalLM: _Llama_sample,
}
def _model_mlp_convert():
for module, replaced_func in _REPLACED_MLP_LAYERS.items():
setattr(module, "forward", replaced_func)
def _model_sample_convert():
setattr(Sampler, "_get_logits", _sample_get_logits)
for module, replaced_func in _REPLACED_SAMPLER_LAYERS.items():
setattr(module, "sample", replaced_func)
def _model_attention_convert():
for module, replaced_func in _REPLACED_ATTENTION_LAYERS.items():
setattr(module, "forward", replaced_func)
@ -160,6 +227,7 @@ def get_load_function(low_bit):
def _ipex_llm_load_model(self) -> None:
_model_mlp_convert()
_model_attention_convert()
_model_sample_convert()
from vllm.utils import measure_device_memory
with measure_device_memory() as m: