Implement selective batching for vLLM (#9659)
* add control to load hf model * finish initial version of selective_batching * temp * finish * Remove print statement * fix error * Apply yang's optimization * a version that works * We need to check kv_cache passed in, this could be an error. TODO: add fast decoding path * format * temp solution: not batching prefill requests * a version that works for prefill batching * format * a solid version: works normally * a temp version * Solid version: remove redundant functions * fix format * format * solid: add option to enable selective_batching * remove logic for using transformer models * format * format * solid: enable argument VLLM_ENABLE_SELECTIVE_BATCHING * format * finish * format
This commit is contained in:
parent
2f36769208
commit
fdf93c9267
4 changed files with 467 additions and 36 deletions
|
|
@ -46,6 +46,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
|||
from .utils import logger
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
import os
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
||||
|
||||
|
|
@ -386,6 +387,8 @@ def convert_forward(m, target_m, new_forward):
|
|||
def _optimize_post(model, lightweight_bmm=False):
|
||||
from packaging import version
|
||||
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
|
||||
from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31
|
||||
from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
|
||||
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
|
||||
from bigdl.llm.transformers.models.llama import llama_mlp_forward
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
|
@ -396,6 +399,10 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
"supported for further optimizations")
|
||||
return model
|
||||
|
||||
vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
|
||||
enable_vllm_se_batching = vllm_selective_batching is not None
|
||||
enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
|
||||
|
||||
trans_version = transformers.__version__
|
||||
if version.parse(trans_version) >= version.parse("4.31.0"):
|
||||
convert_forward(
|
||||
|
|
@ -409,6 +416,17 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
convert_forward(model,
|
||||
transformers.models.llama.modeling_llama.LlamaMLP,
|
||||
llama_mlp_forward)
|
||||
if enable_vllm_se_batching:
|
||||
convert_forward(
|
||||
model,
|
||||
transformers.models.llama.modeling_llama.LlamaModel,
|
||||
llama_model_selective_batching_forward_4_31,
|
||||
)
|
||||
convert_forward(
|
||||
model,
|
||||
transformers.models.llama.modeling_llama.LlamaAttention,
|
||||
llama_attention_selective_batching_forward_4_31,
|
||||
)
|
||||
else:
|
||||
# todo implement 4.28.0 ~ 4.30.2
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -34,15 +34,17 @@
|
|||
import torch
|
||||
import importlib
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Union, List
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
|
||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
|
||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
|
|
@ -191,7 +193,6 @@ def llama_attention_forward_4_31(
|
|||
value_states = [F.linear(hidden_states, value_slices[i])
|
||||
for i in range(self.config.pretraining_tp)]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
|
|
@ -305,6 +306,167 @@ def llama_attention_forward_4_31(
|
|||
return attn_output.to(original_dtype), attn_weights, past_key_value
|
||||
|
||||
|
||||
def llama_attention_selective_batching_forward_4_31(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
device = hidden_states.device
|
||||
# for flash attention
|
||||
original_dtype = hidden_states.dtype
|
||||
# TODO: consider this later - flash attention
|
||||
# if not self.training and not hidden_states.requires_grad:
|
||||
# fsdp_flag = check_flash_attention_available(hidden_states)
|
||||
# else:
|
||||
# fsdp_flag = False
|
||||
# if fsdp_flag and q_len > 1:
|
||||
# attention_dtype = torch.float16 # use fp16 for flash attention
|
||||
# else:
|
||||
# attention_dtype = original_dtype
|
||||
|
||||
attention_dtype = original_dtype
|
||||
|
||||
# TODO: decoding fast path
|
||||
# use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||
# enough_kv_room = is_enough_kv_cache_room(past_key_value[0])
|
||||
# is_q4_0 = self.q_proj.qtype == SYM_INT4
|
||||
# no_tp = not self.config.pretraining_tp > 1
|
||||
# decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
|
||||
# enough_kv_room and bsz * q_len == 1)
|
||||
|
||||
# single batch decoding fast path
|
||||
# forward_qkv takes will perform QKV projection, rotary position embedding
|
||||
# and save the key/value states to cache, then return query states and the
|
||||
# extended key/value cache
|
||||
# if decoding_fast_path:
|
||||
# hidden_states = hidden_states.view(1, -1)
|
||||
# kv_seq_len = past_key_value[0].shape[-2]
|
||||
# cache_k = past_key_value[0]
|
||||
# cache_v = past_key_value[1]
|
||||
# import linear_q4_0
|
||||
# query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
|
||||
# self.q_proj.weight,
|
||||
# self.k_proj.weight,
|
||||
# self.v_proj.weight,
|
||||
# position_ids,
|
||||
# cache_k, cache_v,
|
||||
# self.q_proj.weight.qtype,
|
||||
# kv_seq_len,
|
||||
# self.head_dim)
|
||||
# kv_seq_len += 1
|
||||
|
||||
# else:
|
||||
if self.config.pretraining_tp > 1:
|
||||
invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len,
|
||||
self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len,
|
||||
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len,
|
||||
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
|
||||
|
||||
# TODO: fuse_rope
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||
cos, sin, position_ids, "llama")
|
||||
|
||||
updated_past_key_values = []
|
||||
if past_key_value is not None:
|
||||
batched_attention_output = []
|
||||
# print(f"type of attention_mask is {type(attention_mask)}")
|
||||
for batch in range(bsz):
|
||||
past_k, past_v = past_key_value[batch]
|
||||
current_kv_len = past_k.shape[-2] + 1
|
||||
|
||||
current_key_states = torch.cat([past_k,
|
||||
key_states[batch: batch + 1, :, :, :]], dim=2)
|
||||
current_value_states = torch.cat([past_v,
|
||||
value_states[batch: batch + 1, :, :, :]], dim=2)
|
||||
|
||||
updated_past_key_values.append((current_key_states, current_value_states))
|
||||
|
||||
current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
|
||||
current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)
|
||||
|
||||
current_query_states = query_states[batch: batch + 1, :, :, :]
|
||||
attn_output, attn_weights = native_sdp(current_query_states,
|
||||
current_key_states,
|
||||
current_value_states,
|
||||
attention_mask[batch],
|
||||
1,
|
||||
1,
|
||||
current_kv_len,
|
||||
self.head_dim,
|
||||
self.num_heads)
|
||||
if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
|
||||
invalidInputError(False,
|
||||
f"`attn_output` should be of size "
|
||||
f"{(1, self.num_heads, 1, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}")
|
||||
batched_attention_output.append(attn_output)
|
||||
# For loop ends
|
||||
# TODO: handle attention_weights later
|
||||
attn_output = torch.concat(batched_attention_output, dim=0)
|
||||
batched_attention_output.clear()
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
invalidInputError(False,
|
||||
f"`attn_output` should be of size "
|
||||
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}")
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, None, updated_past_key_values
|
||||
|
||||
# TODO: Assume always use_cache
|
||||
# print(f"prefill with batch size {bsz}")
|
||||
for batch in range(bsz):
|
||||
updated_past_key_values.append((key_states[batch: batch + 1, :, :, :],
|
||||
value_states[batch: batch+1, :, :, :]))
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
|
||||
dtype=attention_dtype)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
||||
dtype=attention_dtype)
|
||||
attn_output, attn_weights = native_sdp(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
bsz,
|
||||
q_len,
|
||||
kv_seq_len,
|
||||
self.head_dim,
|
||||
self.num_heads)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
invalidInputError(False,
|
||||
f"`attn_output` should be of size "
|
||||
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}")
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output.to(original_dtype), attn_weights, updated_past_key_values
|
||||
|
||||
|
||||
def check_flash_attention_available(query):
|
||||
# check whether ipex flash attention can be used
|
||||
if query.device.type != "xpu":
|
||||
|
|
@ -371,3 +533,171 @@ def native_sdp(query, key, value, attention_mask,
|
|||
dtype=torch.float32).to(value.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def llama_model_selective_batching_forward_4_31(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
if output_attentions is not None:
|
||||
output_attentions = output_attentions
|
||||
else:
|
||||
output_attentions = self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
invalidInputError(False,
|
||||
"You cannot specify both decoder_input_ids"
|
||||
" and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
invalidInputError(False,
|
||||
"You have to specify either "
|
||||
"decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
# seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
# The original position_ids in the format of [1, 1]
|
||||
# However, this only applies when kv_len is the same for all the sequences
|
||||
# We should set it to format of [batch, position_id]
|
||||
# TODO: validate correctness
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if position_ids is None:
|
||||
invalidInputError("vLLM: position_ids should never be None")
|
||||
else:
|
||||
# print(f"Original position_ids is {position_ids}")
|
||||
position_ids = position_ids.view(-1, seq_length)
|
||||
# print(f"after position_ids is {position_ids}")
|
||||
# if past_key_values is None:
|
||||
# # For prefill
|
||||
# position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
|
||||
# position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
# else:
|
||||
# past_key_values_length = []
|
||||
# for sequence_kv in past_key_values[0]:
|
||||
# key = sequence_kv[0]
|
||||
# past_key_values_length.append(key.shape[-2])
|
||||
# position_ids = torch.tensor(past_key_values_length, dtype=torch.long, device=device)
|
||||
# position_ids = position_ids.unsqueeze(0).view(-1, 1)
|
||||
|
||||
if past_key_values is not None:
|
||||
# past_key_values in the format of num_layers x num_seqs x 2
|
||||
# TODO: this may be incorrect
|
||||
past_key_values_length = past_key_values[0][0][0].shape[2]
|
||||
# seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
# if position_ids is None:
|
||||
# device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
# # [start, end)
|
||||
# position_ids = torch.arange(
|
||||
# past_key_values_length, seq_length +
|
||||
# past_key_values_length, dtype=torch.long, device=device
|
||||
# )
|
||||
# position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
# else:
|
||||
# position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
invalidInputError(False, "attention_mask should never be None")
|
||||
# print(f"attention_mask before expanding: {attention_mask}")
|
||||
if past_key_values is None:
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
else:
|
||||
i = 0
|
||||
for attn_mask in attention_mask:
|
||||
past_key_value_length = past_key_values[0][i][0].shape[2]
|
||||
new_mask = self._prepare_decoder_attention_mask(
|
||||
attn_mask, (1, seq_length), inputs_embeds, past_key_value_length
|
||||
)
|
||||
attention_mask[i] = new_mask
|
||||
i += 1
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
invalidInputError(False, "gradient_checkpointing is not supported")
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) # noqa
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from bigdl.llm.vllm.logger import init_logger
|
|||
import math
|
||||
import time
|
||||
from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
|
||||
import os
|
||||
from transformers.generation.logits_process import (
|
||||
LogitsProcessorList,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
|
|
@ -50,6 +51,10 @@ def _get_attention_mask_for_prompts(
|
|||
]
|
||||
return attention_mask
|
||||
|
||||
vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
|
||||
enable_vllm_se_batching = vllm_selective_batching is not None
|
||||
enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
|
||||
|
||||
|
||||
class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
||||
|
||||
|
|
@ -61,12 +66,9 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
|||
):
|
||||
super().__init__(config, device, max_model_len)
|
||||
self.config = config
|
||||
# TODO(gc): later change this to a switch?
|
||||
if True:
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
from bigdl.llm import optimize_model
|
||||
|
||||
# low_bit = 'sym_int4'
|
||||
# Always enable bigdl-llm model
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
from bigdl.llm import optimize_model
|
||||
if device == 'cpu':
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
config._name_or_path,
|
||||
|
|
@ -81,7 +83,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
|||
import intel_extension_for_pytorch as ipex
|
||||
except ImportError:
|
||||
print("Intel Extension for PyTorch is not installed, \
|
||||
but is required for xpu inference.")
|
||||
but is required for xpu inference.")
|
||||
|
||||
low_bit = 'sym_int4'
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
|
|
@ -93,17 +95,19 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
|||
self.model = model.to('xpu')
|
||||
self.sampler = BigDLSampler(config.vocab_size, device).to('xpu')
|
||||
|
||||
if device is None:
|
||||
self.device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
self.device = torch.device(device)
|
||||
self.device = torch.device(device)
|
||||
self.dtype = self.model.dtype
|
||||
self.last_seq_ids = []
|
||||
self.tmp_kv_cache = None
|
||||
self.last_kv_cache = None
|
||||
self.pad_token_id = config.pad_token_id
|
||||
self.max_seq_limit = max_model_len
|
||||
|
||||
# GC: Note for selective batching
|
||||
# KV_CACHE in the format of num_layers x 2 x (seq_id -> torch.Tensor)
|
||||
# past_key_values in the format of num_layers x len(seq_id) x (2 x torch.Tensor)
|
||||
# If we set num_layers to 9, have 10 sequences in total.
|
||||
# then, for the kv_cache, we get 9 x 2 x 10 = 180 tensors
|
||||
# for past_key_values, we get 9 x 10 x 2 = 180 tensors
|
||||
def forward(
|
||||
self,
|
||||
seq_group_meta_data_lists: List[SequenceGroupMetadata],
|
||||
|
|
@ -116,7 +120,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
|||
decoder_kv_size = 2
|
||||
|
||||
bigdl_input_ids = []
|
||||
bigdl_position_ids = []
|
||||
# bigdl_position_ids = []
|
||||
bigdl_attention_mask = []
|
||||
|
||||
cur_seq_ids = []
|
||||
|
|
@ -144,8 +148,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
|||
# 1. Assemble bigdl_input_ids end
|
||||
|
||||
if is_decoding_stage:
|
||||
bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
|
||||
kv_cache, num_layers, decoder_kv_size)
|
||||
construct_kv_cache_func = self.get_construct_kv_cache_func(enable_vllm_se_batching)
|
||||
bigdl_kv_cache = construct_kv_cache_func(cur_seq_ids,
|
||||
seq_group_meta_data_lists,
|
||||
kv_cache,
|
||||
num_layers,
|
||||
2)
|
||||
else:
|
||||
bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
|
||||
bigdl_input_ids = [
|
||||
|
|
@ -153,41 +161,72 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
|||
for input_ids in bigdl_input_ids
|
||||
]
|
||||
|
||||
decoding_attention_mask_list = []
|
||||
decoding_position_ids = []
|
||||
# num_layers x len(seq_id) x (2 x torch.Tensor)
|
||||
if is_decoding_stage:
|
||||
cur_seq_len = bigdl_kv_cache[0][0].size(2)
|
||||
for seq_group_meta_data in seq_group_meta_data_lists:
|
||||
seq_ids = list(seq_group_meta_data.seq_data.keys())
|
||||
seq_id = seq_ids[0]
|
||||
seq_data = seq_group_meta_data.seq_data[seq_id]
|
||||
cur_pos = seq_data.get_len()
|
||||
# bigdl_position_ids.append([cur_pos - 1])
|
||||
cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
|
||||
bigdl_attention_mask.append(cur_attention_mask)
|
||||
if enable_vllm_se_batching:
|
||||
batch = 0
|
||||
for seq_group_meta_data in seq_group_meta_data_lists:
|
||||
# Get current seq_len in kv_cache
|
||||
current_seq_len = bigdl_kv_cache[0][batch][0].size(2)
|
||||
batch += 1
|
||||
seq_ids = list(seq_group_meta_data.seq_data.keys())
|
||||
seq_data = seq_group_meta_data.seq_data[seq_ids[0]]
|
||||
cur_pos = seq_data.get_len()
|
||||
decoding_position_ids.append(cur_pos - 1)
|
||||
# Total length: current_seq_len + 1
|
||||
cur_attention_mask = [0] * (current_seq_len - cur_pos + 1) + [1] * (cur_pos)
|
||||
decoding_attention_mask_list.append(cur_attention_mask)
|
||||
else:
|
||||
cur_seq_len = bigdl_kv_cache[0][0].size(2)
|
||||
for seq_group_meta_data in seq_group_meta_data_lists:
|
||||
seq_ids = list(seq_group_meta_data.seq_data.keys())
|
||||
seq_id = seq_ids[0]
|
||||
seq_data = seq_group_meta_data.seq_data[seq_id]
|
||||
cur_pos = seq_data.get_len()
|
||||
# bigdl_position_ids.append([cur_pos - 1])
|
||||
# decoding_position_ids.append(cur_pos - 1)
|
||||
cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
|
||||
decoding_attention_mask_list.append(cur_attention_mask)
|
||||
|
||||
bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
|
||||
|
||||
if is_decoding_stage:
|
||||
# bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
|
||||
bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
|
||||
if enable_vllm_se_batching:
|
||||
attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0)
|
||||
for x in decoding_attention_mask_list]
|
||||
position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1)
|
||||
else:
|
||||
attention_mask = torch.tensor(decoding_attention_mask_list, device=self.device)
|
||||
position_ids = None
|
||||
kwargs = {
|
||||
"input_ids": bigdl_input_ids,
|
||||
# "position_ids": bigdl_position_ids,
|
||||
"attention_mask": bigdl_attention_mask,
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": bigdl_kv_cache,
|
||||
"use_cache": True,
|
||||
# "return_dict": True,
|
||||
}
|
||||
else:
|
||||
# Prefill stage
|
||||
attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
|
||||
if enable_vllm_se_batching:
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
else:
|
||||
position_ids = None
|
||||
kwargs = {
|
||||
"input_ids": bigdl_input_ids,
|
||||
"attention_mask": torch.tensor(bigdl_attention_mask, device=self.device),
|
||||
# "position_ids": bigdl_position_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": None,
|
||||
"use_cache": True,
|
||||
# "return_dict": True,
|
||||
}
|
||||
# Prefill may need additional space, which forces us to delete the last_kv_cache
|
||||
if self.last_kv_cache:
|
||||
del self.last_kv_cache
|
||||
self.last_kv_cache = None
|
||||
# pdb.set_trace()
|
||||
|
||||
if self.device.type == 'xpu':
|
||||
|
|
@ -207,8 +246,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
|||
# tmp = torch.xpu.memory_stats()
|
||||
# logger.info(f"before: {tmp['allocated_bytes.all.current']}")
|
||||
|
||||
self.update_kv_cache(cur_seq_ids,
|
||||
kv_cache, num_layers, decoder_kv_size)
|
||||
if enable_vllm_se_batching:
|
||||
self.update_kv_cache_selective_batching(
|
||||
cur_seq_ids, kv_cache, num_layers, decoder_kv_size)
|
||||
self.last_kv_cache = None
|
||||
else:
|
||||
self.update_kv_cache(cur_seq_ids, kv_cache, num_layers, decoder_kv_size)
|
||||
|
||||
# tmp = torch.xpu.memory_stats()
|
||||
# logger.info(f"after: {tmp['allocated_bytes.all.current']}")
|
||||
|
|
|
|||
|
|
@ -137,6 +137,34 @@ class BigDLModelForCausalLM(nn.Module):
|
|||
|
||||
return bigdl_kv_cache
|
||||
|
||||
def get_construct_kv_cache_func(self, enable_selective_batching):
|
||||
if enable_selective_batching:
|
||||
return self.prepare_kv_cache_selective_batching
|
||||
else:
|
||||
return self.prepare_kv_cache
|
||||
|
||||
# This is an implementation for models that KV Cache shape in (batch_size, num_heads,
|
||||
# sequence_length, embed_size_per_head).
|
||||
def prepare_kv_cache_selective_batching(
|
||||
self,
|
||||
cur_seq_ids: List[int],
|
||||
seq_group_meta_data_lists: List[SequenceGroupMetadata],
|
||||
kv_cache: Dict,
|
||||
num_layers: int,
|
||||
kv_cache_size_1: int,
|
||||
):
|
||||
# Return bigdl_kv_cache in the format of Tuple(List[Tuple(torch.Tensor)])
|
||||
bigdl_kv_cache = []
|
||||
for i in range(num_layers):
|
||||
# Construct a list of tuple(tensor)
|
||||
temp_cache = []
|
||||
for seq_id in cur_seq_ids:
|
||||
key = kv_cache[i][0][seq_id]
|
||||
value = kv_cache[i][1][seq_id]
|
||||
temp_cache.append((key, value))
|
||||
bigdl_kv_cache.append(temp_cache)
|
||||
return bigdl_kv_cache
|
||||
|
||||
# This is an implementation for models that KV Cache shape in (batch_size, num_heads,
|
||||
# sequence_length, embed_size_per_head).
|
||||
def update_kv_cache(
|
||||
|
|
@ -153,6 +181,18 @@ class BigDLModelForCausalLM(nn.Module):
|
|||
kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim]
|
||||
batch_dim = batch_dim + 1
|
||||
|
||||
def update_kv_cache_selective_batching(
|
||||
self,
|
||||
cur_seq_ids: List[int],
|
||||
kv_cache,
|
||||
layer: int,
|
||||
kv_cache_size_1: int,
|
||||
) -> None:
|
||||
for i in range(layer):
|
||||
for j in range(len(cur_seq_ids)):
|
||||
kv_cache[i][0][cur_seq_ids[j]] = self.last_kv_cache[i][j][0]
|
||||
kv_cache[i][1][cur_seq_ids[j]] = self.last_kv_cache[i][j][1]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
seq_group_meta_data_lists: List[SequenceGroupMetadata],
|
||||
|
|
|
|||
Loading…
Reference in a new issue