Implement selective batching for vLLM (#9659)

* add control to load hf model

* finish initial version of selective_batching

* temp

* finish

* Remove print statement

* fix error

* Apply yang's optimization

* a version that works

* We need to check kv_cache passed in, this could be an error. TODO: add fast decoding path

* format

* temp solution: not batching prefill requests

* a version that works for prefill batching

* format

* a solid version: works normally

* a temp version

* Solid version: remove redundant functions

* fix format

* format

* solid: add option to enable selective_batching

* remove logic for using transformer models

* format

* format

* solid: enable argument VLLM_ENABLE_SELECTIVE_BATCHING

* format

* finish

* format
This commit is contained in:
Guancheng Fu 2023-12-22 13:45:46 +08:00 committed by GitHub
parent 2f36769208
commit fdf93c9267
4 changed files with 467 additions and 36 deletions

View file

@ -46,6 +46,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from .utils import logger
from typing import Union
import numpy as np
import os
from bigdl.llm.utils.common import invalidInputError
@ -386,6 +387,8 @@ def convert_forward(m, target_m, new_forward):
def _optimize_post(model, lightweight_bmm=False):
from packaging import version
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31
from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
from bigdl.llm.transformers.models.llama import llama_mlp_forward
from transformers.modeling_utils import PreTrainedModel
@ -396,6 +399,10 @@ def _optimize_post(model, lightweight_bmm=False):
"supported for further optimizations")
return model
vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
enable_vllm_se_batching = vllm_selective_batching is not None
enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
trans_version = transformers.__version__
if version.parse(trans_version) >= version.parse("4.31.0"):
convert_forward(
@ -409,6 +416,17 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model,
transformers.models.llama.modeling_llama.LlamaMLP,
llama_mlp_forward)
if enable_vllm_se_batching:
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaModel,
llama_model_selective_batching_forward_4_31,
)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_selective_batching_forward_4_31,
)
else:
# todo implement 4.28.0 ~ 4.30.2
pass

View file

@ -34,15 +34,17 @@
import torch
import importlib
import torch.nn as nn
from typing import Optional, Tuple
from typing import Optional, Tuple, Union, List
import math
import torch.nn.functional as F
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from transformers.modeling_outputs import BaseModelOutputWithPast
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@ -191,7 +193,6 @@ def llama_attention_forward_4_31(
value_states = [F.linear(hidden_states, value_slices[i])
for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@ -305,6 +306,167 @@ def llama_attention_forward_4_31(
return attn_output.to(original_dtype), attn_weights, past_key_value
def llama_attention_selective_batching_forward_4_31(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
# for flash attention
original_dtype = hidden_states.dtype
# TODO: consider this later - flash attention
# if not self.training and not hidden_states.requires_grad:
# fsdp_flag = check_flash_attention_available(hidden_states)
# else:
# fsdp_flag = False
# if fsdp_flag and q_len > 1:
# attention_dtype = torch.float16 # use fp16 for flash attention
# else:
# attention_dtype = original_dtype
attention_dtype = original_dtype
# TODO: decoding fast path
# use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
# enough_kv_room = is_enough_kv_cache_room(past_key_value[0])
# is_q4_0 = self.q_proj.qtype == SYM_INT4
# no_tp = not self.config.pretraining_tp > 1
# decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
# enough_kv_room and bsz * q_len == 1)
# single batch decoding fast path
# forward_qkv takes will perform QKV projection, rotary position embedding
# and save the key/value states to cache, then return query states and the
# extended key/value cache
# if decoding_fast_path:
# hidden_states = hidden_states.view(1, -1)
# kv_seq_len = past_key_value[0].shape[-2]
# cache_k = past_key_value[0]
# cache_v = past_key_value[1]
# import linear_q4_0
# query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
# self.q_proj.weight,
# self.k_proj.weight,
# self.v_proj.weight,
# position_ids,
# cache_k, cache_v,
# self.q_proj.weight.qtype,
# kv_seq_len,
# self.head_dim)
# kv_seq_len += 1
# else:
if self.config.pretraining_tp > 1:
invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
# TODO: fuse_rope
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama")
updated_past_key_values = []
if past_key_value is not None:
batched_attention_output = []
# print(f"type of attention_mask is {type(attention_mask)}")
for batch in range(bsz):
past_k, past_v = past_key_value[batch]
current_kv_len = past_k.shape[-2] + 1
current_key_states = torch.cat([past_k,
key_states[batch: batch + 1, :, :, :]], dim=2)
current_value_states = torch.cat([past_v,
value_states[batch: batch + 1, :, :, :]], dim=2)
updated_past_key_values.append((current_key_states, current_value_states))
current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)
current_query_states = query_states[batch: batch + 1, :, :, :]
attn_output, attn_weights = native_sdp(current_query_states,
current_key_states,
current_value_states,
attention_mask[batch],
1,
1,
current_kv_len,
self.head_dim,
self.num_heads)
if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
invalidInputError(False,
f"`attn_output` should be of size "
f"{(1, self.num_heads, 1, self.head_dim)}, but is"
f" {attn_output.size()}")
batched_attention_output.append(attn_output)
# For loop ends
# TODO: handle attention_weights later
attn_output = torch.concat(batched_attention_output, dim=0)
batched_attention_output.clear()
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(False,
f"`attn_output` should be of size "
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, updated_past_key_values
# TODO: Assume always use_cache
# print(f"prefill with batch size {bsz}")
for batch in range(bsz):
updated_past_key_values.append((key_states[batch: batch + 1, :, :, :],
value_states[batch: batch+1, :, :, :]))
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
dtype=attention_dtype)
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
dtype=attention_dtype)
attn_output, attn_weights = native_sdp(query_states,
key_states,
value_states,
attention_mask,
bsz,
q_len,
kv_seq_len,
self.head_dim,
self.num_heads)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(False,
f"`attn_output` should be of size "
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output.to(original_dtype), attn_weights, updated_past_key_values
def check_flash_attention_available(query):
# check whether ipex flash attention can be used
if query.device.type != "xpu":
@ -371,3 +533,171 @@ def native_sdp(query, key, value, attention_mask,
dtype=torch.float32).to(value.dtype)
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def llama_model_selective_batching_forward_4_31(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
if output_attentions is not None:
output_attentions = output_attentions
else:
output_attentions = self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
invalidInputError(False,
"You cannot specify both decoder_input_ids"
" and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
invalidInputError(False,
"You have to specify either "
"decoder_input_ids or decoder_inputs_embeds")
# seq_length_with_past = seq_length
past_key_values_length = 0
# The original position_ids in the format of [1, 1]
# However, this only applies when kv_len is the same for all the sequences
# We should set it to format of [batch, position_id]
# TODO: validate correctness
device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
invalidInputError("vLLM: position_ids should never be None")
else:
# print(f"Original position_ids is {position_ids}")
position_ids = position_ids.view(-1, seq_length)
# print(f"after position_ids is {position_ids}")
# if past_key_values is None:
# # For prefill
# position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
# position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
# else:
# past_key_values_length = []
# for sequence_kv in past_key_values[0]:
# key = sequence_kv[0]
# past_key_values_length.append(key.shape[-2])
# position_ids = torch.tensor(past_key_values_length, dtype=torch.long, device=device)
# position_ids = position_ids.unsqueeze(0).view(-1, 1)
if past_key_values is not None:
# past_key_values in the format of num_layers x num_seqs x 2
# TODO: this may be incorrect
past_key_values_length = past_key_values[0][0][0].shape[2]
# seq_length_with_past = seq_length_with_past + past_key_values_length
# if position_ids is None:
# device = input_ids.device if input_ids is not None else inputs_embeds.device
# # [start, end)
# position_ids = torch.arange(
# past_key_values_length, seq_length +
# past_key_values_length, dtype=torch.long, device=device
# )
# position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
# else:
# position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
invalidInputError(False, "attention_mask should never be None")
# print(f"attention_mask before expanding: {attention_mask}")
if past_key_values is None:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
else:
i = 0
for attn_mask in attention_mask:
past_key_value_length = past_key_values[0][i][0].shape[2]
new_mask = self._prepare_decoder_attention_mask(
attn_mask, (1, seq_length), inputs_embeds, past_key_value_length
)
attention_mask[i] = new_mask
i += 1
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
invalidInputError(False, "gradient_checkpointing is not supported")
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) # noqa
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)

View file

@ -27,6 +27,7 @@ from bigdl.llm.vllm.logger import init_logger
import math
import time
from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
import os
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
@ -50,6 +51,10 @@ def _get_attention_mask_for_prompts(
]
return attention_mask
vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
enable_vllm_se_batching = vllm_selective_batching is not None
enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
@ -61,12 +66,9 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
):
super().__init__(config, device, max_model_len)
self.config = config
# TODO(gc): later change this to a switch?
if True:
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm import optimize_model
# low_bit = 'sym_int4'
# Always enable bigdl-llm model
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm import optimize_model
if device == 'cpu':
model = AutoModelForCausalLM.from_pretrained(
config._name_or_path,
@ -81,7 +83,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
import intel_extension_for_pytorch as ipex
except ImportError:
print("Intel Extension for PyTorch is not installed, \
but is required for xpu inference.")
but is required for xpu inference.")
low_bit = 'sym_int4'
model = AutoModelForCausalLM.from_pretrained(
@ -93,17 +95,19 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
self.model = model.to('xpu')
self.sampler = BigDLSampler(config.vocab_size, device).to('xpu')
if device is None:
self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device)
self.device = torch.device(device)
self.dtype = self.model.dtype
self.last_seq_ids = []
self.tmp_kv_cache = None
self.last_kv_cache = None
self.pad_token_id = config.pad_token_id
self.max_seq_limit = max_model_len
# GC: Note for selective batching
# KV_CACHE in the format of num_layers x 2 x (seq_id -> torch.Tensor)
# past_key_values in the format of num_layers x len(seq_id) x (2 x torch.Tensor)
# If we set num_layers to 9, have 10 sequences in total.
# then, for the kv_cache, we get 9 x 2 x 10 = 180 tensors
# for past_key_values, we get 9 x 10 x 2 = 180 tensors
def forward(
self,
seq_group_meta_data_lists: List[SequenceGroupMetadata],
@ -116,7 +120,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
decoder_kv_size = 2
bigdl_input_ids = []
bigdl_position_ids = []
# bigdl_position_ids = []
bigdl_attention_mask = []
cur_seq_ids = []
@ -144,8 +148,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
# 1. Assemble bigdl_input_ids end
if is_decoding_stage:
bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
kv_cache, num_layers, decoder_kv_size)
construct_kv_cache_func = self.get_construct_kv_cache_func(enable_vllm_se_batching)
bigdl_kv_cache = construct_kv_cache_func(cur_seq_ids,
seq_group_meta_data_lists,
kv_cache,
num_layers,
2)
else:
bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
bigdl_input_ids = [
@ -153,41 +161,72 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
for input_ids in bigdl_input_ids
]
decoding_attention_mask_list = []
decoding_position_ids = []
# num_layers x len(seq_id) x (2 x torch.Tensor)
if is_decoding_stage:
cur_seq_len = bigdl_kv_cache[0][0].size(2)
for seq_group_meta_data in seq_group_meta_data_lists:
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
seq_data = seq_group_meta_data.seq_data[seq_id]
cur_pos = seq_data.get_len()
# bigdl_position_ids.append([cur_pos - 1])
cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
bigdl_attention_mask.append(cur_attention_mask)
if enable_vllm_se_batching:
batch = 0
for seq_group_meta_data in seq_group_meta_data_lists:
# Get current seq_len in kv_cache
current_seq_len = bigdl_kv_cache[0][batch][0].size(2)
batch += 1
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_data = seq_group_meta_data.seq_data[seq_ids[0]]
cur_pos = seq_data.get_len()
decoding_position_ids.append(cur_pos - 1)
# Total length: current_seq_len + 1
cur_attention_mask = [0] * (current_seq_len - cur_pos + 1) + [1] * (cur_pos)
decoding_attention_mask_list.append(cur_attention_mask)
else:
cur_seq_len = bigdl_kv_cache[0][0].size(2)
for seq_group_meta_data in seq_group_meta_data_lists:
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
seq_data = seq_group_meta_data.seq_data[seq_id]
cur_pos = seq_data.get_len()
# bigdl_position_ids.append([cur_pos - 1])
# decoding_position_ids.append(cur_pos - 1)
cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
decoding_attention_mask_list.append(cur_attention_mask)
bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
if is_decoding_stage:
# bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
if enable_vllm_se_batching:
attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0)
for x in decoding_attention_mask_list]
position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1)
else:
attention_mask = torch.tensor(decoding_attention_mask_list, device=self.device)
position_ids = None
kwargs = {
"input_ids": bigdl_input_ids,
# "position_ids": bigdl_position_ids,
"attention_mask": bigdl_attention_mask,
"position_ids": position_ids,
"attention_mask": attention_mask,
"past_key_values": bigdl_kv_cache,
"use_cache": True,
# "return_dict": True,
}
else:
# Prefill stage
attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
if enable_vllm_se_batching:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
else:
position_ids = None
kwargs = {
"input_ids": bigdl_input_ids,
"attention_mask": torch.tensor(bigdl_attention_mask, device=self.device),
# "position_ids": bigdl_position_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": None,
"use_cache": True,
# "return_dict": True,
}
# Prefill may need additional space, which forces us to delete the last_kv_cache
if self.last_kv_cache:
del self.last_kv_cache
self.last_kv_cache = None
# pdb.set_trace()
if self.device.type == 'xpu':
@ -207,8 +246,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
# tmp = torch.xpu.memory_stats()
# logger.info(f"before: {tmp['allocated_bytes.all.current']}")
self.update_kv_cache(cur_seq_ids,
kv_cache, num_layers, decoder_kv_size)
if enable_vllm_se_batching:
self.update_kv_cache_selective_batching(
cur_seq_ids, kv_cache, num_layers, decoder_kv_size)
self.last_kv_cache = None
else:
self.update_kv_cache(cur_seq_ids, kv_cache, num_layers, decoder_kv_size)
# tmp = torch.xpu.memory_stats()
# logger.info(f"after: {tmp['allocated_bytes.all.current']}")

View file

@ -137,6 +137,34 @@ class BigDLModelForCausalLM(nn.Module):
return bigdl_kv_cache
def get_construct_kv_cache_func(self, enable_selective_batching):
if enable_selective_batching:
return self.prepare_kv_cache_selective_batching
else:
return self.prepare_kv_cache
# This is an implementation for models that KV Cache shape in (batch_size, num_heads,
# sequence_length, embed_size_per_head).
def prepare_kv_cache_selective_batching(
self,
cur_seq_ids: List[int],
seq_group_meta_data_lists: List[SequenceGroupMetadata],
kv_cache: Dict,
num_layers: int,
kv_cache_size_1: int,
):
# Return bigdl_kv_cache in the format of Tuple(List[Tuple(torch.Tensor)])
bigdl_kv_cache = []
for i in range(num_layers):
# Construct a list of tuple(tensor)
temp_cache = []
for seq_id in cur_seq_ids:
key = kv_cache[i][0][seq_id]
value = kv_cache[i][1][seq_id]
temp_cache.append((key, value))
bigdl_kv_cache.append(temp_cache)
return bigdl_kv_cache
# This is an implementation for models that KV Cache shape in (batch_size, num_heads,
# sequence_length, embed_size_per_head).
def update_kv_cache(
@ -153,6 +181,18 @@ class BigDLModelForCausalLM(nn.Module):
kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim]
batch_dim = batch_dim + 1
def update_kv_cache_selective_batching(
self,
cur_seq_ids: List[int],
kv_cache,
layer: int,
kv_cache_size_1: int,
) -> None:
for i in range(layer):
for j in range(len(cur_seq_ids)):
kv_cache[i][0][cur_seq_ids[j]] = self.last_kv_cache[i][j][0]
kv_cache[i][1][cur_seq_ids[j]] = self.last_kv_cache[i][j][1]
def forward(
self,
seq_group_meta_data_lists: List[SequenceGroupMetadata],