fix vllm qwen2 models (#11879)
parent bd1e490d62
commit 537c0d2767

1 changed file with 12 additions and 4 deletions
@@ -30,7 +30,7 @@ from vllm.config import DeviceConfig
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather
 
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union
 from ipex_llm.utils.common import invalidInputError
 from vllm.sequence import SamplerOutput
 
@@ -51,8 +51,10 @@ def _Qwen2_sample(
     sampling_metadata: SamplingMetadata,
 ) -> Optional[SamplerOutput]:
     if self.config.tie_word_embeddings:
-        lm_head_weight = self.model.embed_tokens
+        # Embedding layer is not optimized to LowBitLinear
+        lm_head_weight = self.model.embed_tokens.weight
     else:
+        # This layer is optimized to LowBitLinear
         lm_head_weight = self.lm_head
     next_tokens = self.sampler(lm_head_weight, hidden_states,
                                sampling_metadata)
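Context for this hunk: with tie_word_embeddings, the output projection reuses the input embedding matrix. The old code handed the nn.Embedding module itself to the sampler, but an embedding module's forward expects integer token indices, not float hidden states, so it cannot be called to produce logits; the fix passes the raw weight tensor instead. A minimal standalone sketch of the failure mode and the working projection (hypothetical vocab/hidden sizes, not the actual ipex-llm code):

import torch

# Hypothetical stand-ins for the model's vocab/hidden dimensions.
vocab_size, hidden_size = 32000, 4096
embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)
hidden_states = torch.randn(2, hidden_size)

# embed_tokens(hidden_states) would raise: Embedding.forward expects
# integer indices, not float hidden states. Passing the weight tensor
# lets the sampler project logits with a plain matmul instead:
logits = torch.matmul(hidden_states, embed_tokens.weight.t())
print(logits.shape)  # torch.Size([2, 32000])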
@@ -70,9 +72,15 @@ def _Chatglm_sample(
     return next_tokens
 
 
-def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
+def _sample_get_logits(self, hidden_states: torch.Tensor,
+                       embedding: Union[torch.nn.Module, torch.Tensor],
                        embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-    logits = embedding(hidden_states)
+    # For tie_word_embedding models, the embedding is not optimized as
+    # the low_bit_linear layer...
+    if isinstance(embedding, torch.Tensor):
+        logits = torch.matmul(hidden_states, embedding.t())
+    else:
+        logits = embedding(hidden_states)
     if embedding_bias is not None:
         logits += embedding_bias
     logits = tensor_model_parallel_gather(logits)
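A minimal standalone sketch of the dispatch the patched _sample_get_logits performs, with the tensor_model_parallel_gather step omitted and hypothetical sizes (this is an illustration under those assumptions, not the actual ipex-llm code):

import torch
from typing import Optional, Union

def get_logits_sketch(hidden_states: torch.Tensor,
                      embedding: Union[torch.nn.Module, torch.Tensor],
                      embedding_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    if isinstance(embedding, torch.Tensor):
        # Tied-embedding path: raw (vocab, hidden) weight, project manually.
        logits = torch.matmul(hidden_states, embedding.t())
    else:
        # Optimized linear path (LowBitLinear in ipex-llm): call the module.
        logits = embedding(hidden_states)
    if embedding_bias is not None:
        logits += embedding_bias
    return logits

h = torch.randn(2, 4096)
tied_weight = torch.randn(32000, 4096)              # tie_word_embeddings=True
lm_head = torch.nn.Linear(4096, 32000, bias=False)  # stand-in for LowBitLinear
assert get_logits_sketch(h, tied_weight).shape == (2, 32000)
assert get_logits_sketch(h, lm_head).shape == (2, 32000)

Dispatching on isinstance keeps a single sampler entry point while supporting both the optimized lm_head module and the untouched tied embedding weight.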