fix vllm qwen2 models (#11879)
commit 537c0d2767
parent bd1e490d62
1 changed file with 12 additions and 4 deletions
@@ -30,7 +30,7 @@ from vllm.config import DeviceConfig
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather
 
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union
 from ipex_llm.utils.common import invalidInputError
 from vllm.sequence import SamplerOutput
 
@@ -51,8 +51,10 @@ def _Qwen2_sample(
     sampling_metadata: SamplingMetadata,
 ) -> Optional[SamplerOutput]:
     if self.config.tie_word_embeddings:
-        lm_head_weight = self.model.embed_tokens
+        # Embedding layer is not optimized to LowBitLinear
+        lm_head_weight = self.model.embed_tokens.weight
     else:
+        # This layer is optimized to LowBitLinear
         lm_head_weight = self.lm_head
     next_tokens = self.sampler(lm_head_weight, hidden_states,
                                sampling_metadata)
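Why the tied branch now passes a plain tensor: with tie_word_embeddings the output projection reuses the input embedding matrix, and per the comment above the embedding layer is not converted to LowBitLinear, so the sampler receives embed_tokens.weight directly instead of a module. A minimal sketch of the equivalence being relied on, with illustrative shapes and names (not taken from the patched file):

import torch

vocab_size, hidden_size = 32, 8
embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)

# A tied model scores the vocabulary with the transpose of its input
# embedding matrix, so the raw weight tensor is all the sampler needs.
hidden_states = torch.randn(3, hidden_size)          # (batch, hidden)
logits = hidden_states @ embed_tokens.weight.t()     # (batch, vocab)
assert logits.shape == (3, vocab_size)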
@@ -70,8 +72,14 @@ def _Chatglm_sample(
     return next_tokens
 
 
-def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
+def _sample_get_logits(self, hidden_states: torch.Tensor,
+                       embedding: Union[torch.nn.Module, torch.Tensor],
                        embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-    logits = embedding(hidden_states)
+    # For tie_word_embedding models, the embedding is not optimized as
+    # the low_bit_linear layer...
+    if isinstance(embedding, torch.Tensor):
+        logits = torch.matmul(hidden_states, embedding.t())
+    else:
+        logits = embedding(hidden_states)
     if embedding_bias is not None:
         logits += embedding_bias
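_sample_get_logits now dispatches on the type of embedding: a bare tensor (the tied embedding weight) is multiplied against directly, while a module such as the LowBitLinear-optimized lm_head is still called as before. A standalone sketch of that dispatch under illustrative names (the real function is bound as a method and takes self):

from typing import Optional, Union
import torch

def get_logits(hidden_states: torch.Tensor,
               embedding: Union[torch.nn.Module, torch.Tensor],
               embedding_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Tensor -> tied embedding weight, multiply directly; Module -> call it.
    if isinstance(embedding, torch.Tensor):
        logits = torch.matmul(hidden_states, embedding.t())
    else:
        logits = embedding(hidden_states)
    if embedding_bias is not None:
        logits += embedding_bias
    return logits

# Both paths agree when the head's weight is the embedding matrix:
head = torch.nn.Linear(8, 32, bias=False)
x = torch.randn(3, 8)
assert torch.allclose(get_logits(x, head), get_logits(x, head.weight))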