fix vllm qwen2 models (#11879)
parent bd1e490d62
commit 537c0d2767
1 changed file with 12 additions and 4 deletions
@@ -30,7 +30,7 @@ from vllm.config import DeviceConfig
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union
 from ipex_llm.utils.common import invalidInputError
 from vllm.sequence import SamplerOutput
@@ -51,8 +51,10 @@ def _Qwen2_sample(
     sampling_metadata: SamplingMetadata,
 ) -> Optional[SamplerOutput]:
     if self.config.tie_word_embeddings:
-        lm_head_weight = self.model.embed_tokens
+        # Embedding layer is not optimized to LowBitLinear
+        lm_head_weight = self.model.embed_tokens.weight
     else:
+        # This layer is optimized to LowBitLinear
         lm_head_weight = self.lm_head
     next_tokens = self.sampler(lm_head_weight, hidden_states,
                                sampling_metadata)
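The fix above hinges on what ipex-llm optimizes: per the added comments, lm_head is converted to a LowBitLinear module, while embed_tokens in tied-weight models remains a plain nn.Embedding. An nn.Embedding cannot be called on float hidden states (it expects integer token ids), so the sampler must receive its raw weight tensor and project via an explicit matmul. A minimal sketch of the two paths with toy sizes (the names and dimensions here are illustrative, not the actual ipex-llm classes):

import torch
import torch.nn as nn

# Hypothetical toy sizes for illustration only.
vocab_size, hidden_size = 32, 8
hidden_states = torch.randn(3, hidden_size)

# Tied-weights path: embed_tokens remains a plain nn.Embedding, so pass
# its weight tensor and project to vocabulary logits with a matmul.
embed_tokens = nn.Embedding(vocab_size, hidden_size)
logits_tied = torch.matmul(hidden_states, embed_tokens.weight.t())

# Untied path: lm_head is a linear layer (LowBitLinear after conversion),
# so it can simply be called on the hidden states.
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
logits_untied = lm_head(hidden_states)

assert logits_tied.shape == logits_untied.shape == (3, vocab_size)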
@@ -70,8 +72,14 @@ def _Chatglm_sample(
     return next_tokens


-def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
+def _sample_get_logits(self, hidden_states: torch.Tensor,
+                       embedding: Union[torch.nn.Module, torch.Tensor],
                        embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-    logits = embedding(hidden_states)
+    # For tie_word_embedding models, the embedding is not optimized as
+    # the low_bit_linear layer...
+    if isinstance(embedding, torch.Tensor):
+        logits = torch.matmul(hidden_states, embedding.t())
+    else:
+        logits = embedding(hidden_states)
     if embedding_bias is not None:
         logits += embedding_bias
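This last hunk makes _sample_get_logits accept either form, branching on the runtime type. A self-contained sketch of the same branch logic as a free function (assuming no self, and omitting any tensor-parallel gathering the surrounding file may perform):

from typing import Optional, Union

import torch


def sample_get_logits(hidden_states: torch.Tensor,
                      embedding: Union[torch.nn.Module, torch.Tensor],
                      embedding_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # A bare tensor means a tied embedding weight that was never converted
    # to a low-bit linear layer: project manually.
    if isinstance(embedding, torch.Tensor):
        logits = torch.matmul(hidden_states, embedding.t())
    else:
        # Otherwise it is a callable (low-bit) linear head.
        logits = embedding(hidden_states)
    if embedding_bias is not None:
        logits += embedding_bias
    return logits


# Both call forms produce (batch, vocab) logits:
weight = torch.randn(32, 8)                  # tied embedding weight
head = torch.nn.Linear(8, 32, bias=False)    # linear head
x = torch.randn(3, 8)
assert sample_get_logits(x, weight).shape == (3, 32)
assert sample_get_logits(x, head).shape == (3, 32)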