diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 9fe3f3ba..ca0b9f86 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -46,6 +46,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from .utils import logger
 from typing import Union
 import numpy as np
+import os
 from bigdl.llm.utils.common import invalidInputError
 
 
@@ -386,6 +387,8 @@ def convert_forward(m, target_m, new_forward):
 def _optimize_post(model, lightweight_bmm=False):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
+    from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31
+    from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
     from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
     from bigdl.llm.transformers.models.llama import llama_mlp_forward
     from transformers.modeling_utils import PreTrainedModel
@@ -396,6 +399,10 @@ def _optimize_post(model, lightweight_bmm=False):
                     "supported for further optimizations")
         return model
 
+    vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
+    enable_vllm_se_batching = vllm_selective_batching is not None
+    enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
+
     trans_version = transformers.__version__
     if version.parse(trans_version) >= version.parse("4.31.0"):
         convert_forward(
@@ -409,6 +416,17 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         transformers.models.llama.modeling_llama.LlamaMLP,
                         llama_mlp_forward)
+        if enable_vllm_se_batching:
+            convert_forward(
+                model,
+                transformers.models.llama.modeling_llama.LlamaModel,
+                llama_model_selective_batching_forward_4_31,
+            )
+            convert_forward(
+                model,
+                transformers.models.llama.modeling_llama.LlamaAttention,
+                llama_attention_selective_batching_forward_4_31,
+            )
     else:
         # todo implement 4.28.0 ~ 4.30.2
         pass
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index e2c251df..d3630362 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -34,15 +34,17 @@
 import torch
 import importlib
 import torch.nn as nn
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union, List
 import math
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.utils.common import invalidInputError
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -191,7 +193,6 @@ def llama_attention_forward_4_31(
         value_states = [F.linear(hidden_states, value_slices[i])
                         for i in range(self.config.pretraining_tp)]
         value_states = torch.cat(value_states, dim=-1)
-
     else:
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
@@ -305,6 +306,167 @@ def llama_attention_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, past_key_value
 
 
+def llama_attention_selective_batching_forward_4_31(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
+    # TODO: consider this later - flash attention
+    # if not self.training and not hidden_states.requires_grad:
+    #     fsdp_flag = check_flash_attention_available(hidden_states)
+    # else:
+    #     fsdp_flag = False
+    # if fsdp_flag and q_len > 1:
+    #     attention_dtype = torch.float16  # use fp16 for flash attention
+    # else:
+    #     attention_dtype = original_dtype
+
+    attention_dtype = original_dtype
+
+    # TODO: decoding fast path
+    # use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    # enough_kv_room = is_enough_kv_cache_room(past_key_value[0])
+    # is_q4_0 = self.q_proj.qtype == SYM_INT4
+    # no_tp = not self.config.pretraining_tp > 1
+    # decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
+    #                       enough_kv_room and bsz * q_len == 1)
+
+    # single batch decoding fast path
+    # forward_qkv takes will perform QKV projection, rotary position embedding
+    # and save the key/value states to cache, then return query states and the
+    # extended key/value cache
+    # if decoding_fast_path:
+    #     hidden_states = hidden_states.view(1, -1)
+    #     kv_seq_len = past_key_value[0].shape[-2]
+    #     cache_k = past_key_value[0]
+    #     cache_v = past_key_value[1]
+    #     import linear_q4_0
+    #     query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
+    #                                                                      self.q_proj.weight,
+    #                                                                      self.k_proj.weight,
+    #                                                                      self.v_proj.weight,
+    #                                                                      position_ids,
+    #                                                                      cache_k, cache_v,
+    #                                                                      self.q_proj.weight.qtype,
+    #                                                                      kv_seq_len,
+    #                                                                      self.head_dim)
+    #     kv_seq_len += 1
+
+    # else:
+    if self.config.pretraining_tp > 1:
+        invalidInputError(False, "vLLM: config.pretraining_tp > 1 not supported yet")
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(bsz, q_len,
+                                     self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len,
+                                 self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len,
+                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
+
+    # TODO: fuse_rope
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                    cos, sin, position_ids, "llama")
+
+    updated_past_key_values = []
+    if past_key_value is not None:
+        batched_attention_output = []
+        # print(f"type of attention_mask is {type(attention_mask)}")
+        for batch in range(bsz):
+            past_k, past_v = past_key_value[batch]
+            current_kv_len = past_k.shape[-2] + 1
+
+            current_key_states = torch.cat([past_k,
+                                            key_states[batch: batch + 1, :, :, :]], dim=2)
+            current_value_states = torch.cat([past_v,
+                                              value_states[batch: batch + 1, :, :, :]], dim=2)
+
+            updated_past_key_values.append((current_key_states,
+                                            current_value_states))
+
+            current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
+            current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)
+
+            current_query_states = query_states[batch: batch + 1, :, :, :]
+            attn_output, attn_weights = native_sdp(current_query_states,
+                                                   current_key_states,
+                                                   current_value_states,
+                                                   attention_mask[batch],
+                                                   1,
+                                                   1,
+                                                   current_kv_len,
+                                                   self.head_dim,
+                                                   self.num_heads)
+            if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
+                invalidInputError(False,
+                                  f"`attn_output` should be of size "
+                                  f"{(1, self.num_heads, 1, self.head_dim)}, but is"
+                                  f" {attn_output.size()}")
+            batched_attention_output.append(attn_output)
+        # For loop ends
+        # TODO: handle attention_weights later
+        attn_output = torch.concat(batched_attention_output, dim=0)
+        batched_attention_output.clear()
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            invalidInputError(False,
+                              f"`attn_output` should be of size "
+                              f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                              f" {attn_output.size()}")
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+        return attn_output, None, updated_past_key_values
+
+    # TODO: Assume always use_cache
+    # print(f"prefill with batch size {bsz}")
+    for batch in range(bsz):
+        updated_past_key_values.append((key_states[batch: batch + 1, :, :, :],
+                                        value_states[batch: batch + 1, :, :, :]))
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                     dtype=attention_dtype)
+    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+    attn_output, attn_weights = native_sdp(query_states,
+                                           key_states,
+                                           value_states,
+                                           attention_mask,
+                                           bsz,
+                                           q_len,
+                                           kv_seq_len,
+                                           self.head_dim,
+                                           self.num_heads)
+
+    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+        invalidInputError(False,
+                          f"`attn_output` should be of size "
+                          f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                          f" {attn_output.size()}")
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+    return attn_output.to(original_dtype), attn_weights, updated_past_key_values
+
+
 def check_flash_attention_available(query):
     # check whether ipex flash attention can be used
     if query.device.type != "xpu":
@@ -371,3 +533,171 @@ def native_sdp(query, key, value, attention_mask,
                                    dtype=torch.float32).to(value.dtype)
     attn_output = torch.matmul(attn_weights, value)
     return attn_output, attn_weights
+
+
+def llama_model_selective_batching_forward_4_31(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = (
+        output_attentions if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        invalidInputError(False,
+                          "You cannot specify both decoder_input_ids"
+                          " and decoder_inputs_embeds at the same time")
+    elif input_ids is not None:
+        batch_size, seq_length = input_ids.shape
+    elif inputs_embeds is not None:
+        batch_size, seq_length, _ = inputs_embeds.shape
+    else:
+        invalidInputError(False,
+                          "You have to specify either "
+                          "decoder_input_ids or decoder_inputs_embeds")
+
+    # seq_length_with_past = seq_length
+    past_key_values_length = 0
+
+    # The original position_ids in the format of [1, 1]
+    # However, this only applies when kv_len is the same for all the sequences
+    # We should set it to format of [batch, position_id]
+    # TODO: validate correctness
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
+    if position_ids is None:
+        invalidInputError(False, "vLLM: position_ids should never be None")
+    else:
+        # print(f"Original position_ids is {position_ids}")
+        position_ids = position_ids.view(-1, seq_length)
+        # print(f"after position_ids is {position_ids}")
+    # if past_key_values is None:
+    #     # For prefill
+    #     position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
+    #     position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+    # else:
+    #     past_key_values_length = []
+    #     for sequence_kv in past_key_values[0]:
+    #         key = sequence_kv[0]
+    #         past_key_values_length.append(key.shape[-2])
+    #     position_ids = torch.tensor(past_key_values_length, dtype=torch.long, device=device)
+    #     position_ids = position_ids.unsqueeze(0).view(-1, 1)
+
+    if past_key_values is not None:
+        # past_key_values in the format of num_layers x num_seqs x 2
+        # TODO: this may be incorrect
+        past_key_values_length = past_key_values[0][0][0].shape[2]
+        # seq_length_with_past = seq_length_with_past + past_key_values_length
+
+    # if position_ids is None:
+    #     device = input_ids.device if input_ids is not None else inputs_embeds.device
+    #     # [start, end)
+    #     position_ids = torch.arange(
+    #         past_key_values_length, seq_length +
+    #         past_key_values_length, dtype=torch.long, device=device
+    #     )
+    #     position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+    # else:
+    #     position_ids = position_ids.view(-1, seq_length).long()
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+    # embed positions
+    if attention_mask is None:
+        invalidInputError(False, "attention_mask should never be None")
+    # print(f"attention_mask before expanding: {attention_mask}")
+    if past_key_values is None:
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+    else:
+        i = 0
+        for attn_mask in attention_mask:
+            past_key_value_length = past_key_values[0][i][0].shape[2]
+            new_mask = self._prepare_decoder_attention_mask(
+                attn_mask, (1, seq_length), inputs_embeds, past_key_value_length
+            )
+            attention_mask[i] = new_mask
+            i += 1
+
+    hidden_states = inputs_embeds
+
+    if self.gradient_checkpointing and self.training:
+        invalidInputError(False, "gradient_checkpointing is not supported")
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = () if use_cache else None
+
+    for idx, decoder_layer in enumerate(self.layers):
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+        if self.gradient_checkpointing and self.training:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    # None for past_key_value
+                    return module(*inputs, output_attentions, None)
+
+                return custom_forward
+
+            layer_outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(decoder_layer),
+                hidden_states,
+                attention_mask,
+                position_ids,
+                None,
+            )
+        else:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)  # noqa
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
index cecf4df6..eb6fa282 100644
--- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
+++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py
@@ -27,6 +27,7 @@ from bigdl.llm.vllm.logger import init_logger
 import math
 import time
 from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
+import os
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -50,6 +51,10 @@ def _get_attention_mask_for_prompts(
     ]
     return attention_mask
 
+vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
+enable_vllm_se_batching = vllm_selective_batching is not None
+enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
+
 
 class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
 
@@ -61,12 +66,9 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
     ):
         super().__init__(config, device, max_model_len)
         self.config = config
-        # TODO(gc): later change this to a switch?
-        if True:
-            from bigdl.llm.transformers import AutoModelForCausalLM
-            from bigdl.llm import optimize_model
-
-            # low_bit = 'sym_int4'
+        # Always enable bigdl-llm model
+        from bigdl.llm.transformers import AutoModelForCausalLM
+        from bigdl.llm import optimize_model
         if device == 'cpu':
             model = AutoModelForCausalLM.from_pretrained(
                 config._name_or_path,
@@ -81,7 +83,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
                 import intel_extension_for_pytorch as ipex
             except ImportError:
                 print("Intel Extension for PyTorch is not installed, \
-                   but is required for xpu inference.")
+                      but is required for xpu inference.")
 
             low_bit = 'sym_int4'
             model = AutoModelForCausalLM.from_pretrained(
@@ -93,17 +95,19 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
             self.model = model.to('xpu')
             self.sampler = BigDLSampler(config.vocab_size, device).to('xpu')
 
-        if device is None:
-            self.device = torch.device(
-                "cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self.device = torch.device(device)
+        self.device = torch.device(device)
         self.dtype = self.model.dtype
         self.last_seq_ids = []
-        self.tmp_kv_cache = None
+        self.last_kv_cache = None
         self.pad_token_id = config.pad_token_id
         self.max_seq_limit = max_model_len
 
+    # GC: Note for selective batching
+    # KV_CACHE in the format of num_layers x 2 x (seq_id -> torch.Tensor)
+    # past_key_values in the format of num_layers x len(seq_id) x (2 x torch.Tensor)
+    # If we set num_layers to 9, have 10 sequences in total.
+    # then, for the kv_cache, we get 9 x 2 x 10 = 180 tensors
+    # for past_key_values, we get 9 x 10 x 2 = 180 tensors
     def forward(
         self,
         seq_group_meta_data_lists: List[SequenceGroupMetadata],
@@ -116,7 +120,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
         decoder_kv_size = 2
 
         bigdl_input_ids = []
-        bigdl_position_ids = []
+        # bigdl_position_ids = []
         bigdl_attention_mask = []
 
         cur_seq_ids = []
@@ -144,8 +148,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
         # 1. Assemble bigdl_input_ids end
 
         if is_decoding_stage:
-            bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
-                                                   kv_cache, num_layers, decoder_kv_size)
+            construct_kv_cache_func = self.get_construct_kv_cache_func(enable_vllm_se_batching)
+            bigdl_kv_cache = construct_kv_cache_func(cur_seq_ids,
+                                                     seq_group_meta_data_lists,
+                                                     kv_cache,
+                                                     num_layers,
+                                                     2)
         else:
             bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
             bigdl_input_ids = [
@@ -153,41 +161,72 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
                 for input_ids in bigdl_input_ids
             ]
 
+        decoding_attention_mask_list = []
+        decoding_position_ids = []
+        # num_layers x len(seq_id) x (2 x torch.Tensor)
         if is_decoding_stage:
-            cur_seq_len = bigdl_kv_cache[0][0].size(2)
-            for seq_group_meta_data in seq_group_meta_data_lists:
-                seq_ids = list(seq_group_meta_data.seq_data.keys())
-                seq_id = seq_ids[0]
-                seq_data = seq_group_meta_data.seq_data[seq_id]
-                cur_pos = seq_data.get_len()
-                # bigdl_position_ids.append([cur_pos - 1])
-                cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
-                bigdl_attention_mask.append(cur_attention_mask)
+            if enable_vllm_se_batching:
+                batch = 0
+                for seq_group_meta_data in seq_group_meta_data_lists:
+                    # Get current seq_len in kv_cache
+                    current_seq_len = bigdl_kv_cache[0][batch][0].size(2)
+                    batch += 1
+                    seq_ids = list(seq_group_meta_data.seq_data.keys())
+                    seq_data = seq_group_meta_data.seq_data[seq_ids[0]]
+                    cur_pos = seq_data.get_len()
+                    decoding_position_ids.append(cur_pos - 1)
+                    # Total length: current_seq_len + 1
+                    cur_attention_mask = [0] * (current_seq_len - cur_pos + 1) + [1] * (cur_pos)
+                    decoding_attention_mask_list.append(cur_attention_mask)
+            else:
+                cur_seq_len = bigdl_kv_cache[0][0].size(2)
+                for seq_group_meta_data in seq_group_meta_data_lists:
+                    seq_ids = list(seq_group_meta_data.seq_data.keys())
+                    seq_id = seq_ids[0]
+                    seq_data = seq_group_meta_data.seq_data[seq_id]
+                    cur_pos = seq_data.get_len()
+                    # bigdl_position_ids.append([cur_pos - 1])
+                    # decoding_position_ids.append(cur_pos - 1)
+                    cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
+                    decoding_attention_mask_list.append(cur_attention_mask)
 
         bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
 
         if is_decoding_stage:
-            # bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
-            bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
+            if enable_vllm_se_batching:
+                attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0)
+                                  for x in decoding_attention_mask_list]
+                position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1)
+            else:
+                attention_mask = torch.tensor(decoding_attention_mask_list, device=self.device)
+                position_ids = None
             kwargs = {
                 "input_ids": bigdl_input_ids,
-                # "position_ids": bigdl_position_ids,
-                "attention_mask": bigdl_attention_mask,
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
                 "past_key_values": bigdl_kv_cache,
                 "use_cache": True,
                 # "return_dict": True,
             }
         else:
+            # Prefill stage
+            attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
+            if enable_vllm_se_batching:
+                position_ids = attention_mask.long().cumsum(-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+            else:
+                position_ids = None
             kwargs = {
                 "input_ids": bigdl_input_ids,
-                "attention_mask": torch.tensor(bigdl_attention_mask, device=self.device),
-                # "position_ids": bigdl_position_ids,
+                "attention_mask": attention_mask,
+                "position_ids": position_ids,
                 "past_key_values": None,
                 "use_cache": True,
                 # "return_dict": True,
             }
 
+        # Prefill may need additional space, which forces us to delete the last_kv_cache
         if self.last_kv_cache:
-            del self.last_kv_cache
+            self.last_kv_cache = None
 
         # pdb.set_trace()
         if self.device.type == 'xpu':
@@ -207,8 +246,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
 
         # tmp = torch.xpu.memory_stats()
         # logger.info(f"before: {tmp['allocated_bytes.all.current']}")
-        self.update_kv_cache(cur_seq_ids,
-                             kv_cache, num_layers, decoder_kv_size)
+        if enable_vllm_se_batching:
+            self.update_kv_cache_selective_batching(
+                cur_seq_ids, kv_cache, num_layers, decoder_kv_size)
+            self.last_kv_cache = None
+        else:
+            self.update_kv_cache(cur_seq_ids, kv_cache, num_layers, decoder_kv_size)
 
         # tmp = torch.xpu.memory_stats()
         # logger.info(f"after: {tmp['allocated_bytes.all.current']}")
diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py
index 6694a3f1..a81993dc 100644
--- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py
+++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py
@@ -137,6 +137,34 @@ class BigDLModelForCausalLM(nn.Module):
 
         return bigdl_kv_cache
 
+    def get_construct_kv_cache_func(self, enable_selective_batching):
+        if enable_selective_batching:
+            return self.prepare_kv_cache_selective_batching
+        else:
+            return self.prepare_kv_cache
+
+    # This is an implementation for models whose KV cache shape is (batch_size, num_heads,
+    # sequence_length, embed_size_per_head).
+    def prepare_kv_cache_selective_batching(
+        self,
+        cur_seq_ids: List[int],
+        seq_group_meta_data_lists: List[SequenceGroupMetadata],
+        kv_cache: Dict,
+        num_layers: int,
+        kv_cache_size_1: int,
+    ):
+        # Return bigdl_kv_cache in the format of Tuple(List[Tuple(torch.Tensor)])
+        bigdl_kv_cache = []
+        for i in range(num_layers):
+            # Construct a list of tuple(tensor)
+            temp_cache = []
+            for seq_id in cur_seq_ids:
+                key = kv_cache[i][0][seq_id]
+                value = kv_cache[i][1][seq_id]
+                temp_cache.append((key, value))
+            bigdl_kv_cache.append(temp_cache)
+        return bigdl_kv_cache
+
     # This is an implementation for models that KV Cache shape in (batch_size, num_heads,
     # sequence_length, embed_size_per_head).
     def update_kv_cache(
@@ -153,6 +181,18 @@ class BigDLModelForCausalLM(nn.Module):
                     kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim]
                 batch_dim = batch_dim + 1
 
+    def update_kv_cache_selective_batching(
+        self,
+        cur_seq_ids: List[int],
+        kv_cache,
+        layer: int,
+        kv_cache_size_1: int,
+    ) -> None:
+        for i in range(layer):
+            for j in range(len(cur_seq_ids)):
+                kv_cache[i][0][cur_seq_ids[j]] = self.last_kv_cache[i][j][0]
+                kv_cache[i][1][cur_seq_ids[j]] = self.last_kv_cache[i][j][1]
+
     def forward(
         self,
         seq_group_meta_data_lists: List[SequenceGroupMetadata],