Implement selective batching for vLLM (#9659)

* add control to load hf model

* finish initial version of selective_batching

* temp

* finish

* Remove print statement

* fix error

* Apply yang's optimization

* a version that works

* We need to check the kv_cache passed in; this could be an error. TODO: add fast decoding path

* format

* temp solution: not batching prefill requests

* a version that works for prefill batching

* format

* a solid version: works normally

* a temp version

* Solid version: remove redundant functions

* fix format

* format

* solid: add option to enable selective_batching

* remove logic for using transformer models

* format

* format

* solid: enable argument VLLM_ENABLE_SELECTIVE_BATCHING

* format

* finish

* format
Guancheng Fu 2023-12-22 13:45:46 +08:00 committed by GitHub
parent 2f36769208
commit fdf93c9267
4 changed files with 467 additions and 36 deletions


@@ -46,6 +46,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from .utils import logger
from typing import Union
import numpy as np
import os
from bigdl.llm.utils.common import invalidInputError
@@ -386,6 +387,8 @@ def convert_forward(m, target_m, new_forward):
def _optimize_post(model, lightweight_bmm=False):
from packaging import version
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31
from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
from bigdl.llm.transformers.models.llama import llama_mlp_forward
from transformers.modeling_utils import PreTrainedModel
@@ -396,6 +399,10 @@ def _optimize_post(model, lightweight_bmm=False):
"supported for further optimizations")
return model
vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
enable_vllm_se_batching = vllm_selective_batching is not None
enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
trans_version = transformers.__version__
if version.parse(trans_version) >= version.parse("4.31.0"):
convert_forward(
@@ -409,6 +416,17 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model,
transformers.models.llama.modeling_llama.LlamaMLP,
llama_mlp_forward)
if enable_vllm_se_batching:
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaModel,
llama_model_selective_batching_forward_4_31,
)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_selective_batching_forward_4_31,
)
else:
# todo implement 4.28.0 ~ 4.30.2
pass
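Note: the gate above is driven purely by the VLLM_ENABLE_SELECTIVE_BATCHING environment variable, and only the literal value "true" (case-insensitive) enables it. A minimal usage sketch follows; the helper name `selective_batching_enabled` is hypothetical and only mirrors the parsing logic shown in this hunk.

```python
import os

def selective_batching_enabled() -> bool:
    # Mirrors the gating in _optimize_post above: on only when the variable
    # is set and equals "true", case-insensitively.
    value = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
    return value is not None and value.lower() == "true"

# Example: opt in before loading and optimizing the model.
os.environ["VLLM_ENABLE_SELECTIVE_BATCHING"] = "true"
assert selective_batching_enabled()
```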


@@ -34,15 +34,17 @@
import torch
import importlib
import torch.nn as nn
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union, List
import math
import torch.nn.functional as F
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from transformers.modeling_outputs import BaseModelOutputWithPast
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -191,7 +193,6 @@ def llama_attention_forward_4_31(
value_states = [F.linear(hidden_states, value_slices[i])
for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@@ -305,6 +306,167 @@ def llama_attention_forward_4_31(
return attn_output.to(original_dtype), attn_weights, past_key_value
def llama_attention_selective_batching_forward_4_31(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
# for flash attention
original_dtype = hidden_states.dtype
# TODO: consider this later - flash attention
# if not self.training and not hidden_states.requires_grad:
# fsdp_flag = check_flash_attention_available(hidden_states)
# else:
# fsdp_flag = False
# if fsdp_flag and q_len > 1:
# attention_dtype = torch.float16 # use fp16 for flash attention
# else:
# attention_dtype = original_dtype
attention_dtype = original_dtype
# TODO: decoding fast path
# use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
# enough_kv_room = is_enough_kv_cache_room(past_key_value[0])
# is_q4_0 = self.q_proj.qtype == SYM_INT4
# no_tp = not self.config.pretraining_tp > 1
# decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
# enough_kv_room and bsz * q_len == 1)
# single batch decoding fast path
# forward_qkv takes will perform QKV projection, rotary position embedding
# and save the key/value states to cache, then return query states and the
# extended key/value cache
# if decoding_fast_path:
# hidden_states = hidden_states.view(1, -1)
# kv_seq_len = past_key_value[0].shape[-2]
# cache_k = past_key_value[0]
# cache_v = past_key_value[1]
# import linear_q4_0
# query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
# self.q_proj.weight,
# self.k_proj.weight,
# self.v_proj.weight,
# position_ids,
# cache_k, cache_v,
# self.q_proj.weight.qtype,
# kv_seq_len,
# self.head_dim)
# kv_seq_len += 1
# else:
if self.config.pretraining_tp > 1:
invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
# TODO: fuse_rope
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama")
updated_past_key_values = []
if past_key_value is not None:
batched_attention_output = []
# print(f"type of attention_mask is {type(attention_mask)}")
for batch in range(bsz):
past_k, past_v = past_key_value[batch]
current_kv_len = past_k.shape[-2] + 1
current_key_states = torch.cat([past_k,
key_states[batch: batch + 1, :, :, :]], dim=2)
current_value_states = torch.cat([past_v,
value_states[batch: batch + 1, :, :, :]], dim=2)
updated_past_key_values.append((current_key_states, current_value_states))
current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)
current_query_states = query_states[batch: batch + 1, :, :, :]
attn_output, attn_weights = native_sdp(current_query_states,
current_key_states,
current_value_states,
attention_mask[batch],
1,
1,
current_kv_len,
self.head_dim,
self.num_heads)
if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
invalidInputError(False,
f"`attn_output` should be of size "
f"{(1, self.num_heads, 1, self.head_dim)}, but is"
f" {attn_output.size()}")
batched_attention_output.append(attn_output)
# For loop ends
# TODO: handle attention_weights later
attn_output = torch.concat(batched_attention_output, dim=0)
batched_attention_output.clear()
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(False,
f"`attn_output` should be of size "
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, updated_past_key_values
# TODO: Assume always use_cache
# print(f"prefill with batch size {bsz}")
for batch in range(bsz):
updated_past_key_values.append((key_states[batch: batch + 1, :, :, :],
value_states[batch: batch+1, :, :, :]))
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
dtype=attention_dtype)
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
dtype=attention_dtype)
attn_output, attn_weights = native_sdp(query_states,
key_states,
value_states,
attention_mask,
bsz,
q_len,
kv_seq_len,
self.head_dim,
self.num_heads)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(False,
f"`attn_output` should be of size "
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output.to(original_dtype), attn_weights, updated_past_key_values
def check_flash_attention_available(query):
# check whether ipex flash attention can be used
if query.device.type != "xpu":
@@ -371,3 +533,171 @@ def native_sdp(query, key, value, attention_mask,
dtype=torch.float32).to(value.dtype)
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def llama_model_selective_batching_forward_4_31(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
if output_attentions is not None:
output_attentions = output_attentions
else:
output_attentions = self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
invalidInputError(False,
"You cannot specify both decoder_input_ids"
" and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
invalidInputError(False,
"You have to specify either "
"decoder_input_ids or decoder_inputs_embeds")
# seq_length_with_past = seq_length
past_key_values_length = 0
# The original position_ids in the format of [1, 1]
# However, this only applies when kv_len is the same for all the sequences
# We should set it to format of [batch, position_id]
# TODO: validate correctness
device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
invalidInputError("vLLM: position_ids should never be None")
else:
# print(f"Original position_ids is {position_ids}")
position_ids = position_ids.view(-1, seq_length)
# print(f"after position_ids is {position_ids}")
# if past_key_values is None:
# # For prefill
# position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
# position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
# else:
# past_key_values_length = []
# for sequence_kv in past_key_values[0]:
# key = sequence_kv[0]
# past_key_values_length.append(key.shape[-2])
# position_ids = torch.tensor(past_key_values_length, dtype=torch.long, device=device)
# position_ids = position_ids.unsqueeze(0).view(-1, 1)
if past_key_values is not None:
# past_key_values in the format of num_layers x num_seqs x 2
# TODO: this may be incorrect
past_key_values_length = past_key_values[0][0][0].shape[2]
# seq_length_with_past = seq_length_with_past + past_key_values_length
# if position_ids is None:
# device = input_ids.device if input_ids is not None else inputs_embeds.device
# # [start, end)
# position_ids = torch.arange(
# past_key_values_length, seq_length +
# past_key_values_length, dtype=torch.long, device=device
# )
# position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
# else:
# position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
invalidInputError(False, "attention_mask should never be None")
# print(f"attention_mask before expanding: {attention_mask}")
if past_key_values is None:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
else:
i = 0
for attn_mask in attention_mask:
past_key_value_length = past_key_values[0][i][0].shape[2]
new_mask = self._prepare_decoder_attention_mask(
attn_mask, (1, seq_length), inputs_embeds, past_key_value_length
)
attention_mask[i] = new_mask
i += 1
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
invalidInputError(False, "gradient_checkpointing is not supported")
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) # noqa
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
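Note: the decode path of llama_attention_selective_batching_forward_4_31 loops over the batch because every sequence keeps a KV cache of its own length. The sketch below illustrates that idea in isolation with toy tensors; names and shapes are invented, torch>=2.0 is assumed for F.scaled_dot_product_attention, and the actual patch uses native_sdp instead.

```python
import torch
import torch.nn.functional as F

def per_sequence_decode_attention(query_states, past_key_values):
    # query_states: (batch, num_heads, 1, head_dim) -- one new token per sequence.
    # past_key_values: one (key, value) pair per sequence, each shaped
    # (1, num_heads, kv_len_i, head_dim), where kv_len_i may differ per sequence.
    outputs = []
    for batch, (key, value) in enumerate(past_key_values):
        q = query_states[batch:batch + 1]  # (1, num_heads, 1, head_dim)
        # Attend only against this sequence's own cache, so no cross-sequence
        # padding is ever materialized.
        outputs.append(F.scaled_dot_product_attention(q, key, value))
    return torch.cat(outputs, dim=0)  # (batch, num_heads, 1, head_dim)

# Toy usage: two sequences whose caches hold 3 and 7 tokens respectively.
q = torch.randn(2, 32, 1, 128)
kv = [(torch.randn(1, 32, n, 128), torch.randn(1, 32, n, 128)) for n in (3, 7)]
print(per_sequence_decode_attention(q, kv).shape)  # torch.Size([2, 32, 1, 128])
```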


@@ -27,6 +27,7 @@ from bigdl.llm.vllm.logger import init_logger
import math
import time
from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
import os
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
@@ -50,6 +51,10 @@ def _get_attention_mask_for_prompts(
]
return attention_mask
vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
enable_vllm_se_batching = vllm_selective_batching is not None
enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
@@ -61,12 +66,9 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
):
super().__init__(config, device, max_model_len)
self.config = config
-# TODO(gc): later change this to a switch?
-if True:
-from bigdl.llm.transformers import AutoModelForCausalLM
-from bigdl.llm import optimize_model
-# low_bit = 'sym_int4'
+# Always enable bigdl-llm model
+from bigdl.llm.transformers import AutoModelForCausalLM
+from bigdl.llm import optimize_model
if device == 'cpu':
model = AutoModelForCausalLM.from_pretrained(
config._name_or_path,
@@ -81,7 +83,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
import intel_extension_for_pytorch as ipex
except ImportError:
print("Intel Extension for PyTorch is not installed, \
but is required for xpu inference.")
low_bit = 'sym_int4'
model = AutoModelForCausalLM.from_pretrained(
@@ -93,17 +95,19 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
self.model = model.to('xpu')
self.sampler = BigDLSampler(config.vocab_size, device).to('xpu')
-if device is None:
-self.device = torch.device(
-"cuda" if torch.cuda.is_available() else "cpu")
-else:
-self.device = torch.device(device)
+self.device = torch.device(device)
self.dtype = self.model.dtype
self.last_seq_ids = []
-self.tmp_kv_cache = None
+self.last_kv_cache = None
self.pad_token_id = config.pad_token_id
self.max_seq_limit = max_model_len
# GC: Note for selective batching
# KV_CACHE in the format of num_layers x 2 x (seq_id -> torch.Tensor)
# past_key_values in the format of num_layers x len(seq_id) x (2 x torch.Tensor)
# If we set num_layers to 9, have 10 sequences in total.
# then, for the kv_cache, we get 9 x 2 x 10 = 180 tensors
# for past_key_values, we get 9 x 10 x 2 = 180 tensors
def forward(
self,
seq_group_meta_data_lists: List[SequenceGroupMetadata],
@@ -116,7 +120,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
decoder_kv_size = 2
bigdl_input_ids = []
-bigdl_position_ids = []
+# bigdl_position_ids = []
bigdl_attention_mask = []
cur_seq_ids = []
@@ -144,8 +148,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
# 1. Assemble bigdl_input_ids end
if is_decoding_stage:
-bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
-kv_cache, num_layers, decoder_kv_size)
+construct_kv_cache_func = self.get_construct_kv_cache_func(enable_vllm_se_batching)
+bigdl_kv_cache = construct_kv_cache_func(cur_seq_ids,
+seq_group_meta_data_lists,
+kv_cache,
+num_layers,
+2)
else:
bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
bigdl_input_ids = [
@@ -153,41 +161,72 @@
for input_ids in bigdl_input_ids
]
decoding_attention_mask_list = []
decoding_position_ids = []
# num_layers x len(seq_id) x (2 x torch.Tensor)
if is_decoding_stage:
-cur_seq_len = bigdl_kv_cache[0][0].size(2)
-for seq_group_meta_data in seq_group_meta_data_lists:
-seq_ids = list(seq_group_meta_data.seq_data.keys())
-seq_id = seq_ids[0]
-seq_data = seq_group_meta_data.seq_data[seq_id]
-cur_pos = seq_data.get_len()
-# bigdl_position_ids.append([cur_pos - 1])
-cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
-bigdl_attention_mask.append(cur_attention_mask)
+if enable_vllm_se_batching:
+batch = 0
+for seq_group_meta_data in seq_group_meta_data_lists:
+# Get current seq_len in kv_cache
+current_seq_len = bigdl_kv_cache[0][batch][0].size(2)
+batch += 1
+seq_ids = list(seq_group_meta_data.seq_data.keys())
+seq_data = seq_group_meta_data.seq_data[seq_ids[0]]
+cur_pos = seq_data.get_len()
+decoding_position_ids.append(cur_pos - 1)
+# Total length: current_seq_len + 1
+cur_attention_mask = [0] * (current_seq_len - cur_pos + 1) + [1] * (cur_pos)
+decoding_attention_mask_list.append(cur_attention_mask)
+else:
+cur_seq_len = bigdl_kv_cache[0][0].size(2)
+for seq_group_meta_data in seq_group_meta_data_lists:
+seq_ids = list(seq_group_meta_data.seq_data.keys())
+seq_id = seq_ids[0]
+seq_data = seq_group_meta_data.seq_data[seq_id]
+cur_pos = seq_data.get_len()
+# bigdl_position_ids.append([cur_pos - 1])
+# decoding_position_ids.append(cur_pos - 1)
+cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
+decoding_attention_mask_list.append(cur_attention_mask)
bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
if is_decoding_stage:
-# bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
-bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
+if enable_vllm_se_batching:
+attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0)
+for x in decoding_attention_mask_list]
+position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1)
+else:
+attention_mask = torch.tensor(decoding_attention_mask_list, device=self.device)
+position_ids = None
kwargs = {
"input_ids": bigdl_input_ids,
-# "position_ids": bigdl_position_ids,
-"attention_mask": bigdl_attention_mask,
+"position_ids": position_ids,
+"attention_mask": attention_mask,
"past_key_values": bigdl_kv_cache,
"use_cache": True,
# "return_dict": True,
}
else:
# Prefill stage
attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
if enable_vllm_se_batching:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
else:
position_ids = None
kwargs = {
"input_ids": bigdl_input_ids,
-"attention_mask": torch.tensor(bigdl_attention_mask, device=self.device),
-# "position_ids": bigdl_position_ids,
+"attention_mask": attention_mask,
+"position_ids": position_ids,
"past_key_values": None,
"use_cache": True,
# "return_dict": True,
}
# Prefill may need additional space, which forces us to delete the last_kv_cache
if self.last_kv_cache:
-del self.last_kv_cache
+self.last_kv_cache = None
# pdb.set_trace()
if self.device.type == 'xpu':
@@ -207,8 +246,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
# tmp = torch.xpu.memory_stats()
# logger.info(f"before: {tmp['allocated_bytes.all.current']}")
-self.update_kv_cache(cur_seq_ids,
-kv_cache, num_layers, decoder_kv_size)
+if enable_vllm_se_batching:
+self.update_kv_cache_selective_batching(
+cur_seq_ids, kv_cache, num_layers, decoder_kv_size)
+self.last_kv_cache = None
+else:
+self.update_kv_cache(cur_seq_ids, kv_cache, num_layers, decoder_kv_size)
# tmp = torch.xpu.memory_stats()
# logger.info(f"after: {tmp['allocated_bytes.all.current']}")


@@ -137,6 +137,34 @@ class BigDLModelForCausalLM(nn.Module):
return bigdl_kv_cache
def get_construct_kv_cache_func(self, enable_selective_batching):
if enable_selective_batching:
return self.prepare_kv_cache_selective_batching
else:
return self.prepare_kv_cache
# This is an implementation for models that KV Cache shape in (batch_size, num_heads,
# sequence_length, embed_size_per_head).
def prepare_kv_cache_selective_batching(
self,
cur_seq_ids: List[int],
seq_group_meta_data_lists: List[SequenceGroupMetadata],
kv_cache: Dict,
num_layers: int,
kv_cache_size_1: int,
):
# Return bigdl_kv_cache in the format of Tuple(List[Tuple(torch.Tensor)])
bigdl_kv_cache = []
for i in range(num_layers):
# Construct a list of tuple(tensor)
temp_cache = []
for seq_id in cur_seq_ids:
key = kv_cache[i][0][seq_id]
value = kv_cache[i][1][seq_id]
temp_cache.append((key, value))
bigdl_kv_cache.append(temp_cache)
return bigdl_kv_cache
# This is an implementation for models that KV Cache shape in (batch_size, num_heads,
# sequence_length, embed_size_per_head).
def update_kv_cache(
@@ -153,6 +181,18 @@ class BigDLModelForCausalLM(nn.Module):
kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim]
batch_dim = batch_dim + 1
def update_kv_cache_selective_batching(
self,
cur_seq_ids: List[int],
kv_cache,
layer: int,
kv_cache_size_1: int,
) -> None:
for i in range(layer):
for j in range(len(cur_seq_ids)):
kv_cache[i][0][cur_seq_ids[j]] = self.last_kv_cache[i][j][0]
kv_cache[i][1][cur_seq_ids[j]] = self.last_kv_cache[i][j][1]
def forward(
self,
seq_group_meta_data_lists: List[SequenceGroupMetadata],
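Note: a minimal round-trip sketch of the two cache layouts these helpers translate between. The shapes, sequence ids, and lengths are toy values for illustration only, not the BigDL API.

```python
import torch
from typing import Dict, List, Tuple

num_layers, num_heads, head_dim = 2, 4, 8
seq_ids = [11, 12, 13]

# Engine-side layout: num_layers x 2 x (seq_id -> tensor), indexed as
# kv_cache[layer][0 or 1][seq_id]; each sequence may have a different length.
kv_cache: List[List[Dict[int, torch.Tensor]]] = [
    [{sid: torch.randn(1, num_heads, 3 + i, head_dim) for i, sid in enumerate(seq_ids)}
     for _ in range(2)]
    for _ in range(num_layers)
]

# Model-side layout (what prepare_kv_cache_selective_batching builds):
# num_layers x num_seqs x (key, value), indexed as past_key_values[layer][batch].
past_key_values: List[List[Tuple[torch.Tensor, torch.Tensor]]] = [
    [(kv_cache[layer][0][sid], kv_cache[layer][1][sid]) for sid in seq_ids]
    for layer in range(num_layers)
]

# Writing results back mirrors update_kv_cache_selective_batching.
for layer in range(num_layers):
    for batch, sid in enumerate(seq_ids):
        kv_cache[layer][0][sid], kv_cache[layer][1][sid] = past_key_values[layer][batch]
```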