[LLM] vLLM: Support Mixtral Model (#9670)
Add Mixtral support for BigDL vLLM: register MixtralForCausalLM in _MODEL_REGISTRY and add the BigDLMixtralForCausalLM implementation.
parent dc5b1d7e9d
commit 1c6499e880

2 changed files with 222 additions and 0 deletions
Changed file 1: the vLLM model loader, which imports and registers the new Mixtral class.

@@ -40,6 +40,8 @@ from transformers import PretrainedConfig
 from bigdl.llm.vllm.config import ModelConfig
 from bigdl.llm.vllm.model_executor.models.bigdl_llama import BigDLLlamaForCausalLM
+from bigdl.llm.vllm.model_executor.models.bigdl_mixtral import BigDLMixtralForCausalLM
+
 from bigdl.llm.utils.common import invalidInputError

 # bigdl-llm Intel specified code change
@@ -61,6 +63,7 @@ _MODEL_REGISTRY = {
     "LlamaForCausalLM": BigDLLlamaForCausalLM,
     # "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
     # "MistralForCausalLM": MistralForCausalLM,
+    "MixtralForCausalLM": BigDLMixtralForCausalLM,
     # "MPTForCausalLM": MPTForCausalLM,
     # "OPTForCausalLM": OPTForCausalLM,
     # "QWenLMHeadModel": QWenLMHeadModel,
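For context, the registry above maps the architecture name from a model's Hugging Face config to the BigDL wrapper class that serves it. Below is a minimal sketch of that lookup, assuming the loader dispatches on config.architectures; get_model_class is a hypothetical helper for illustration, not the loader's actual code.

# Illustrative sketch only: how a registry like _MODEL_REGISTRY is typically
# consulted. get_model_class is a hypothetical helper; the string values stand
# in for the real BigDL classes.
_MODEL_REGISTRY = {
    "LlamaForCausalLM": "BigDLLlamaForCausalLM",
    "MixtralForCausalLM": "BigDLMixtralForCausalLM",  # added by this commit
}

def get_model_class(hf_config):
    # Hugging Face configs list the model class name(s), e.g.
    # ["MixtralForCausalLM"] for a Mixtral checkpoint such as
    # mistralai/Mixtral-8x7B-Instruct-v0.1.
    for arch in getattr(hf_config, "architectures", None) or []:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(f"Model architectures {hf_config.architectures!r} are not supported.")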
Changed file 2: new module bigdl.llm.vllm.model_executor.models.bigdl_mixtral (219 added lines).

@@ -0,0 +1,219 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from torch import nn

from transformers import AutoTokenizer, PreTrainedTokenizerBase, LlamaConfig
from typing import Optional, Tuple, List, Type, Dict

from bigdl.llm.vllm.sequence import SequenceOutputs, SequenceGroupMetadata
from bigdl.llm.vllm.model_executor.layers.bigdl_sampler import BigDLSampler
from bigdl.llm.vllm.model_executor.models.bigdl_model import BigDLModelForCausalLM
from bigdl.llm.vllm.logger import init_logger
import math
import time
from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

logger = init_logger(__name__)


def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]:
    return [padding_id] * (max_len - len(x)) + x


def _get_attention_mask_for_prompts(
    input_ids: List[List[int]], max_prompt_len: int
) -> List[List[int]]:
    attention_mask = [
        [0] * (max_prompt_len - len(prompt)) + [1] * len(prompt) for prompt in input_ids
    ]
    return attention_mask


class BigDLMixtralForCausalLM(BigDLModelForCausalLM):

    def __init__(
        self,
        config,
        device: Optional[str] = None,
        max_model_len: Optional[int] = None,
    ):
        super().__init__(config, device, max_model_len)
        self.config = config
        # TODO(gc): later change this to a switch?
        if True:
            from bigdl.llm.transformers import AutoModelForCausalLM
            from bigdl.llm import optimize_model

        # low_bit = 'sym_int4'
        if device == 'cpu':
            model = AutoModelForCausalLM.from_pretrained(
                config._name_or_path,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                use_cache=True,
            )
            self.model = optimize_model(model)
            self.sampler = BigDLSampler(config.vocab_size, device)
        elif device == 'xpu':
            try:
                import intel_extension_for_pytorch as ipex
            except ImportError:
                print("Intel Extension for PyTorch is not installed, \
                       but is required for xpu inference.")

            low_bit = 'sym_int4'
            model = AutoModelForCausalLM.from_pretrained(
                config._name_or_path,
                load_in_low_bit=low_bit,
                trust_remote_code=True,
                optimize_model=True,
                use_cache=True,
            )
            self.model = model.to('xpu')
            self.sampler = BigDLSampler(config.vocab_size, device).to('xpu')

        if device is None:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        self.dtype = self.model.dtype
        self.last_seq_ids = []
        self.tmp_kv_cache = None
        if config.pad_token_id is None:
            self.pad_token_id = config.eos_token_id
        else:
            self.pad_token_id = config.pad_token_id
        self.max_seq_limit = max_model_len

    def forward(
        self,
        seq_group_meta_data_lists: List[SequenceGroupMetadata],
        # kv_cache in the format [[dict() for _ in range(2)] for _ in range(32)]
        kv_cache: Optional[List[List[Dict]]] = None,
        input_metadata: Optional[InputMetadata] = None,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        num_layers = self.model.config.num_hidden_layers
        # One for key, one for value
        decoder_kv_size = 2

        bigdl_input_ids = []
        bigdl_position_ids = []
        bigdl_attention_mask = []

        cur_seq_ids = []
        max_prompt_len = 0

        # 0. Verify is_prompt or is_decoding
        is_decoding_stage = not seq_group_meta_data_lists[0].is_prompt

        # 1. Assemble bigdl_input_ids
        for seq_group_meta_data in seq_group_meta_data_lists:
            # req_id = seq_group_meta_data.request_id
            # is_decoding_stage = is_decoding_stage and (not seq_group_meta_data.is_prompt)
            seq_ids = list(seq_group_meta_data.seq_data.keys())
            seq_id = seq_ids[0]
            cur_seq_ids.append(seq_id)
            seq_data = seq_group_meta_data.seq_data[seq_id]

            cur_seq_input_ids = seq_data.get_token_ids()
            # context_len = seq_data.get_len()
            if seq_group_meta_data.is_prompt:
                bigdl_input_ids.append(cur_seq_input_ids)
                max_prompt_len = max(max_prompt_len, seq_data.get_len())
            else:
                bigdl_input_ids.append([cur_seq_input_ids[-1]])
        # 1. Assemble bigdl_input_ids end

        if is_decoding_stage:
            bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
                                                   kv_cache, num_layers, decoder_kv_size)
        else:
            bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
            bigdl_input_ids = [
                _pad_to_max(input_ids, max_prompt_len, self.pad_token_id)
                for input_ids in bigdl_input_ids
            ]

        if is_decoding_stage:
            cur_seq_len = bigdl_kv_cache[0][0].size(2)
            for seq_group_meta_data in seq_group_meta_data_lists:
                seq_ids = list(seq_group_meta_data.seq_data.keys())
                seq_id = seq_ids[0]
                seq_data = seq_group_meta_data.seq_data[seq_id]
                cur_pos = seq_data.get_len()
                # bigdl_position_ids.append([cur_pos - 1])
                cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
                bigdl_attention_mask.append(cur_attention_mask)

        bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)

        if is_decoding_stage:
            # bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
            bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
            kwargs = {
                "input_ids": bigdl_input_ids,
                # "position_ids": bigdl_position_ids,
                "attention_mask": bigdl_attention_mask,
                "past_key_values": bigdl_kv_cache,
                "use_cache": True,
                # "return_dict": True,
            }
        else:
            kwargs = {
                "input_ids": bigdl_input_ids,
                "attention_mask": torch.tensor(bigdl_attention_mask, device=self.device),
                # "position_ids": bigdl_position_ids,
                "past_key_values": None,
                "use_cache": True,
                # "return_dict": True,
            }
            if self.last_kv_cache:
                self.last_kv_cache = None
        # pdb.set_trace()

        if self.device.type == 'xpu':
            torch.xpu.empty_cache()
        st_timestamp = time.perf_counter()
        outputs = self.model.forward(**kwargs)
        # tmp = torch.xpu.memory_stats()
        # logger.info(f"0: {tmp['allocated_bytes.all.current']}")
        # self.last_seq_ids = cur_seq_ids[:]
        # self.last_kv_cache = outputs.past_key_values
        self._set_last_seq_ids(cur_seq_ids[:])
        self._set_last_kv_cache(outputs.past_key_values)
        # pdb.set_trace()

        logits = outputs.logits[:, -1, :]
        bigdl_output = self.sampler(logits, input_metadata, st_timestamp)
        # tmp = torch.xpu.memory_stats()
        # logger.info(f"before: {tmp['allocated_bytes.all.current']}")

        self.update_kv_cache(cur_seq_ids,
                             kv_cache, num_layers, decoder_kv_size)

        # tmp = torch.xpu.memory_stats()
        # logger.info(f"after: {tmp['allocated_bytes.all.current']}")
        return bigdl_output
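In the prefill branch of forward, each prompt in the batch is left-padded to the longest prompt and given a matching attention mask. A tiny illustration of the two module-level helpers, assuming the definitions above and made-up token ids:

# Tiny illustration of the padding helpers above; the token ids are made up.
prompts = [[5, 6, 7], [9]]
max_prompt_len = max(len(p) for p in prompts)      # 3

masks = _get_attention_mask_for_prompts(prompts, max_prompt_len)
# masks  == [[1, 1, 1], [0, 0, 1]]    (0 marks the left padding)

padded = [_pad_to_max(p, max_prompt_len, padding_id=2) for p in prompts]
# padded == [[5, 6, 7], [2, 2, 9]]    (pad_token_id fills from the left)

In the decode branch, a per-sequence mask is instead built against the cached sequence length taken from bigdl_kv_cache, and only the last generated token id is fed as input.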
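Outside the vLLM worker, the low-bit load performed in the constructor's xpu branch can be sketched as a standalone snippet. This only mirrors the bigdl.llm.transformers calls already shown in the diff; the model id, prompt, and generate() usage below are illustrative assumptions, not part of this commit.

# Standalone sketch of the low-bit load used by BigDLMixtralForCausalLM.
# The model id, prompt, and generate() call are placeholders for illustration.
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"   # any Mixtral checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_low_bit="sym_int4",   # same low_bit value the xpu branch passes
    trust_remote_code=True,
    use_cache=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("What is a mixture-of-experts model?", return_tensors="pt")
with torch.inference_mode():
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))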