fix vllm qwen2 models (#11879)

Guancheng Fu 2024-08-21 11:05:24 +08:00 committed by GitHub
parent bd1e490d62
commit 537c0d2767

@@ -30,7 +30,7 @@ from vllm.config import DeviceConfig
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union
 from ipex_llm.utils.common import invalidInputError
 from vllm.sequence import SamplerOutput
@@ -51,8 +51,10 @@ def _Qwen2_sample(
     sampling_metadata: SamplingMetadata,
 ) -> Optional[SamplerOutput]:
     if self.config.tie_word_embeddings:
-        lm_head_weight = self.model.embed_tokens
+        # Embedding layer is not optimized to LowBitLinear
+        lm_head_weight = self.model.embed_tokens.weight
     else:
+        # This layer is optimized to LowBitLinear
         lm_head_weight = self.lm_head
     next_tokens = self.sampler(lm_head_weight, hidden_states,
                                sampling_metadata)
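The tied-embedding branch now passes the raw weight tensor instead of the nn.Embedding module, because the embedding table is not converted to LowBitLinear. A minimal sketch (illustrative sizes and names, not the ipex-llm code) of why only the weight matrix can stand in for the lm_head:

# Sketch only: nn.Embedding.forward expects integer token ids, so the module
# itself cannot project hidden states; its weight matrix can.
import torch
import torch.nn as nn

vocab_size, hidden_size = 32, 8
embed_tokens = nn.Embedding(vocab_size, hidden_size)  # stays un-optimized
hidden_states = torch.randn(4, hidden_size)

# embed_tokens(hidden_states) would fail here; with tied embeddings we
# reuse the weight matrix as the output projection instead.
logits = torch.matmul(hidden_states, embed_tokens.weight.t())
print(logits.shape)  # torch.Size([4, 32])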
@@ -70,9 +72,15 @@ def _Chatglm_sample(
     return next_tokens
-def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
+def _sample_get_logits(self, hidden_states: torch.Tensor,
+                       embedding: Union[torch.nn.Module, torch.Tensor],
                        embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-    logits = embedding(hidden_states)
+    # For tie_word_embedding models, the embedding is not optimized as
+    # the low_bit_linear layer...
+    if isinstance(embedding, torch.Tensor):
+        logits = torch.matmul(hidden_states, embedding.t())
+    else:
+        logits = embedding(hidden_states)
     if embedding_bias is not None:
         logits += embedding_bias
     logits = tensor_model_parallel_gather(logits)
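
Taken together, the patched _sample_get_logits accepts either an optimized linear module or a plain weight tensor. A hedged, self-contained sketch of that dispatch (function and variable names are illustrative, not the exact vLLM/ipex-llm objects):

# Sketch of the Tensor-vs-Module dispatch added above, under assumed shapes.
from typing import Optional, Union
import torch

def get_logits(hidden_states: torch.Tensor,
               embedding: Union[torch.nn.Module, torch.Tensor],
               embedding_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # A raw weight tensor (tied embeddings, not LowBitLinear) needs an
    # explicit matmul; an optimized linear module is simply called.
    if isinstance(embedding, torch.Tensor):
        logits = torch.matmul(hidden_states, embedding.t())
    else:
        logits = embedding(hidden_states)
    if embedding_bias is not None:
        logits = logits + embedding_bias
    return logits

# Works with either form:
weight = torch.randn(32, 8)                  # (vocab, hidden) weight tensor
linear = torch.nn.Linear(8, 32, bias=False)  # stands in for a LowBitLinear layer
hs = torch.randn(4, 8)
assert get_logits(hs, weight).shape == get_logits(hs, linear).shape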