diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
index 79b56c3b..94494da2 100644
--- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
+++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
@@ -30,7 +30,7 @@
 from vllm.config import DeviceConfig
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union
 from ipex_llm.utils.common import invalidInputError
 from vllm.sequence import SamplerOutput
 
@@ -51,8 +51,10 @@ def _Qwen2_sample(
     sampling_metadata: SamplingMetadata,
 ) -> Optional[SamplerOutput]:
     if self.config.tie_word_embeddings:
-        lm_head_weight = self.model.embed_tokens
+        # Embedding layer is not optimized to LowBitLinear
+        lm_head_weight = self.model.embed_tokens.weight
     else:
+        # This layer is optimized to LowBitLinear
         lm_head_weight = self.lm_head
     next_tokens = self.sampler(lm_head_weight, hidden_states,
                                sampling_metadata)
@@ -70,9 +72,15 @@ def _Chatglm_sample(
     return next_tokens
 
 
-def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
+def _sample_get_logits(self, hidden_states: torch.Tensor,
+                       embedding: Union[torch.nn.Module, torch.Tensor],
                        embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-    logits = embedding(hidden_states)
+    # For tie_word_embedding models, the embedding is not optimized as
+    # the low_bit_linear layer...
+    if isinstance(embedding, torch.Tensor):
+        logits = torch.matmul(hidden_states, embedding.t())
+    else:
+        logits = embedding(hidden_states)
     if embedding_bias is not None:
         logits += embedding_bias
     logits = tensor_model_parallel_gather(logits)
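
Note on the fix: with tie_word_embeddings, the LM head reuses the token-embedding weight, and ipex-llm does not convert the embedding layer to LowBitLinear. The sampler therefore now receives a bare torch.Tensor instead of a callable module, and _sample_get_logits computes the logits with an explicit matmul against the transposed weight. Below is a minimal standalone sketch of that dispatch, outside vLLM and with made-up sizes (get_logits, hidden_size, and vocab_size are illustrative names, not part of the patch):

    import torch

    def get_logits(hidden_states: torch.Tensor,
                   embedding,  # torch.Tensor (tied weight) or nn.Module
                   embedding_bias=None) -> torch.Tensor:
        # Same dispatch as the patched _sample_get_logits: a raw weight
        # tensor is multiplied against the hidden states, while an
        # optimized (LowBitLinear-style) module is simply called.
        if isinstance(embedding, torch.Tensor):
            logits = torch.matmul(hidden_states, embedding.t())
        else:
            logits = embedding(hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        return logits

    hidden_size, vocab_size = 8, 32
    hidden_states = torch.randn(2, hidden_size)

    # Tied case: pass the (vocab_size, hidden_size) weight tensor itself,
    # as _Qwen2_sample now does with self.model.embed_tokens.weight.
    tied_weight = torch.nn.Embedding(vocab_size, hidden_size).weight
    assert get_logits(hidden_states, tied_weight).shape == (2, vocab_size)

    # Untied case: a separate lm_head module (nn.Linear standing in for
    # the LowBitLinear layer produced by ipex-llm's optimization).
    lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    assert get_logits(hidden_states, lm_head).shape == (2, vocab_size)

Passing .weight rather than the nn.Embedding module matters: calling the module would perform an index lookup, not a vocabulary projection, which is why the old `lm_head_weight = self.model.embed_tokens` path was broken.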