Enable ipex-llm optimization for lm head (#11589)
* basic
* Modify convert.py
* fix
parent 365adad59f
commit 06930ab258

2 changed files with 83 additions and 7 deletions
@@ -149,9 +149,11 @@ def is_linear_module(module):
         from vllm.model_executor.layers.linear import (
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         )
+        from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
         VLLM_LINEAR_LIST = [
-            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
+            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
+            MergedColumnParallelLinear,
+            ParallelLMHead
         ]
         if is_module_in_classes(module, VLLM_LINEAR_LIST):
             if 'xpu' in _VLLM_VERSION:
@@ -167,6 +169,12 @@ def is_linear_module(module):
             else:
                 # For vllm cpu
                 tp_size = 1
+            if isinstance(module, ParallelLMHead) and 'xpu' in _VLLM_VERSION:
+                in_features = module.embedding_dim
+                out_features = module.num_embeddings_per_partition
+                result = True
+                mp_group = None
+                return result, (in_features, out_features, mp_group)
             in_features = module.input_size
             out_features = module.output_size
             result = True
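As a reading aid (not part of the commit), the sketch below shows the contract the new ParallelLMHead branch satisfies: the LM head's hidden size is reported as in_features, its per-partition vocab size as out_features, and mp_group stays None. FakeLMHead is a hypothetical stand-in for vLLM's ParallelLMHead.

# Minimal sketch, not part of the commit. `FakeLMHead` is a hypothetical stand-in
# for vLLM's ParallelLMHead, which exposes `embedding_dim` (hidden size) and
# `num_embeddings_per_partition` (vocab size on this tensor-parallel rank).
class FakeLMHead:
    def __init__(self, vocab_size: int, hidden_size: int):
        self.num_embeddings_per_partition = vocab_size
        self.embedding_dim = hidden_size


def describe_lm_head(module: FakeLMHead):
    # Mirrors the added branch: treat the LM head as a linear layer whose
    # input is the hidden size and whose output is the (partitioned) vocab.
    in_features = module.embedding_dim
    out_features = module.num_embeddings_per_partition
    mp_group = None
    return True, (in_features, out_features, mp_group)


print(describe_lm_head(FakeLMHead(vocab_size=32000, hidden_size=4096)))
# (True, (4096, 32000, None))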
@@ -15,18 +15,71 @@
 #
 import torch
 from vllm.logger import init_logger
-from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention
-from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
-from vllm.model_executor.models.qwen import QWenMLP, QWenAttention
+from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention, LlamaForCausalLM
+from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention, Qwen2ForCausalLM
+from vllm.model_executor.models.qwen import QWenMLP, QWenAttention, QWenLMHeadModel
 from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention
-from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
+from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
+from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention, ChatGLMForCausalLM
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.layers.sampler import Sampler
+
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.config import DeviceConfig
-from typing import Tuple
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather
+
+from typing import Tuple, Optional
 from ipex_llm.utils.common import invalidInputError
+from vllm.sequence import SamplerOutput
+
+
+def _Llama_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    next_tokens = self.sampler(self.lm_head, hidden_states,
+                               sampling_metadata)
+    return next_tokens
+
+
+def _Qwen2_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    if self.config.tie_word_embeddings:
+        lm_head_weight = self.model.embed_tokens
+    else:
+        lm_head_weight = self.lm_head
+    next_tokens = self.sampler(lm_head_weight, hidden_states,
+                               sampling_metadata)
+    return next_tokens
+
+
+def _Chatglm_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    next_tokens = self.sampler(self.transformer.output_layer, hidden_states,
+                               sampling_metadata)
+
+    return next_tokens
+
+
+def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
+                       embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+    logits = embedding(hidden_states)
+    if embedding_bias is not None:
+        logits += embedding_bias
+    logits = tensor_model_parallel_gather(logits)
+    # Remove paddings in vocab (if any).
+    if logits is not None:
+        logits = logits[:, :self.org_vocab_size]
+    return logits
+
+
 def _MLP_forward(self, x):
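For orientation only (not from the diff): a minimal numeric sketch of the logits path that the _sample_get_logits replacement implements, assuming a single device so that tensor_model_parallel_gather is effectively a no-op. The sizes below are made up for illustration, and a plain torch.nn.Linear stands in for the converted lm_head module.

# Minimal sketch, assuming no tensor parallelism (the gather is then a no-op).
# The lm_head module maps hidden states to vocab logits when called directly;
# padded vocab columns are trimmed back to org_vocab_size afterwards.
import torch

hidden_size, padded_vocab, org_vocab_size = 4096, 32064, 32000
lm_head = torch.nn.Linear(hidden_size, padded_vocab, bias=False)  # hypothetical stand-in

hidden_states = torch.randn(4, hidden_size)   # [num_tokens, hidden_size]
logits = lm_head(hidden_states)               # embedding(hidden_states) in the diff
logits = logits[:, :org_vocab_size]           # drop vocab padding, as in the diff
print(logits.shape)                           # torch.Size([4, 32000])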
@@ -139,12 +192,26 @@ _REPLACED_ATTENTION_LAYERS = {
     GLMAttention: _ChatGLM_Attention_forward
 }
+
+_REPLACED_SAMPLER_LAYERS = {
+    LlamaForCausalLM: _Llama_sample,
+    QWenLMHeadModel: _Llama_sample,
+    ChatGLMForCausalLM: _Chatglm_sample,
+    Qwen2ForCausalLM: _Qwen2_sample,
+    BaiChuanBaseForCausalLM: _Llama_sample,
+}
 
 
 def _model_mlp_convert():
     for module, replaced_func in _REPLACED_MLP_LAYERS.items():
         setattr(module, "forward", replaced_func)
 
 
+def _model_sample_convert():
+    setattr(Sampler, "_get_logits", _sample_get_logits)
+    for module, replaced_func in _REPLACED_SAMPLER_LAYERS.items():
+        setattr(module, "sample", replaced_func)
+
+
 def _model_attention_convert():
     for module, replaced_func in _REPLACED_ATTENTION_LAYERS.items():
         setattr(module, "forward", replaced_func)
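A side note (not from the commit) on the mechanism these converters rely on: assigning a plain function to a class attribute with setattr replaces the method for every instance, whether created before or after the patch. The names in the sketch are hypothetical.

# Minimal sketch of the class-level patching used by the *_convert helpers.
# `DemoModel` and `_patched_sample` are illustrative names only.
class DemoModel:
    def sample(self, hidden_states):
        return "original sampler"


def _patched_sample(self, hidden_states):
    return "ipex-llm sampler"


model = DemoModel()                            # instance created before the patch
setattr(DemoModel, "sample", _patched_sample)  # same mechanism as the converters above
print(model.sample(None))                      # "ipex-llm sampler"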
@@ -160,6 +227,7 @@ def get_load_function(low_bit):
     def _ipex_llm_load_model(self) -> None:
         _model_mlp_convert()
         _model_attention_convert()
+        _model_sample_convert()
 
         from vllm.utils import measure_device_memory
         with measure_device_memory() as m: