Implement selective batching for vLLM (#9659)
* add control to load hf model * finish initial version of selective_batching * temp * finish * Remove print statement * fix error * Apply yang's optimization * a version that works * We need to check kv_cache passed in, this could be an error. TODO: add fast decoding path * format * temp solution: not batching prefill requests * a version that works for prefill batching * format * a solid version: works normally * a temp version * Solid version: remove redundant functions * fix format * format * solid: add option to enable selective_batching * remove logic for using transformer models * format * format * solid: enable argument VLLM_ENABLE_SELECTIVE_BATCHING * format * finish * format
This commit is contained in:
parent
2f36769208
commit
fdf93c9267
4 changed files with 467 additions and 36 deletions
|
|
@ -46,6 +46,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||||
from .utils import logger
|
from .utils import logger
|
||||||
from typing import Union
|
from typing import Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -386,6 +387,8 @@ def convert_forward(m, target_m, new_forward):
|
||||||
def _optimize_post(model, lightweight_bmm=False):
|
def _optimize_post(model, lightweight_bmm=False):
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
|
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
|
||||||
|
from bigdl.llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31
|
||||||
|
from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
|
||||||
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
|
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
|
||||||
from bigdl.llm.transformers.models.llama import llama_mlp_forward
|
from bigdl.llm.transformers.models.llama import llama_mlp_forward
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
|
@ -396,6 +399,10 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
"supported for further optimizations")
|
"supported for further optimizations")
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
|
||||||
|
enable_vllm_se_batching = vllm_selective_batching is not None
|
||||||
|
enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
|
||||||
|
|
||||||
trans_version = transformers.__version__
|
trans_version = transformers.__version__
|
||||||
if version.parse(trans_version) >= version.parse("4.31.0"):
|
if version.parse(trans_version) >= version.parse("4.31.0"):
|
||||||
convert_forward(
|
convert_forward(
|
||||||
|
|
@ -409,6 +416,17 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
transformers.models.llama.modeling_llama.LlamaMLP,
|
transformers.models.llama.modeling_llama.LlamaMLP,
|
||||||
llama_mlp_forward)
|
llama_mlp_forward)
|
||||||
|
if enable_vllm_se_batching:
|
||||||
|
convert_forward(
|
||||||
|
model,
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel,
|
||||||
|
llama_model_selective_batching_forward_4_31,
|
||||||
|
)
|
||||||
|
convert_forward(
|
||||||
|
model,
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention,
|
||||||
|
llama_attention_selective_batching_forward_4_31,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# todo implement 4.28.0 ~ 4.30.2
|
# todo implement 4.28.0 ~ 4.30.2
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -34,15 +34,17 @@
|
||||||
import torch
|
import torch
|
||||||
import importlib
|
import importlib
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple, Union, List
|
||||||
import math
|
import math
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||||
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
|
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
|
||||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
|
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
|
||||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||||
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
|
@ -191,7 +193,6 @@ def llama_attention_forward_4_31(
|
||||||
value_states = [F.linear(hidden_states, value_slices[i])
|
value_states = [F.linear(hidden_states, value_slices[i])
|
||||||
for i in range(self.config.pretraining_tp)]
|
for i in range(self.config.pretraining_tp)]
|
||||||
value_states = torch.cat(value_states, dim=-1)
|
value_states = torch.cat(value_states, dim=-1)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
query_states = self.q_proj(hidden_states)
|
query_states = self.q_proj(hidden_states)
|
||||||
key_states = self.k_proj(hidden_states)
|
key_states = self.k_proj(hidden_states)
|
||||||
|
|
@ -305,6 +306,167 @@ def llama_attention_forward_4_31(
|
||||||
return attn_output.to(original_dtype), attn_weights, past_key_value
|
return attn_output.to(original_dtype), attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def llama_attention_selective_batching_forward_4_31(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
device = hidden_states.device
|
||||||
|
# for flash attention
|
||||||
|
original_dtype = hidden_states.dtype
|
||||||
|
# TODO: consider this later - flash attention
|
||||||
|
# if not self.training and not hidden_states.requires_grad:
|
||||||
|
# fsdp_flag = check_flash_attention_available(hidden_states)
|
||||||
|
# else:
|
||||||
|
# fsdp_flag = False
|
||||||
|
# if fsdp_flag and q_len > 1:
|
||||||
|
# attention_dtype = torch.float16 # use fp16 for flash attention
|
||||||
|
# else:
|
||||||
|
# attention_dtype = original_dtype
|
||||||
|
|
||||||
|
attention_dtype = original_dtype
|
||||||
|
|
||||||
|
# TODO: decoding fast path
|
||||||
|
# use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
|
# enough_kv_room = is_enough_kv_cache_room(past_key_value[0])
|
||||||
|
# is_q4_0 = self.q_proj.qtype == SYM_INT4
|
||||||
|
# no_tp = not self.config.pretraining_tp > 1
|
||||||
|
# decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
|
||||||
|
# enough_kv_room and bsz * q_len == 1)
|
||||||
|
|
||||||
|
# single batch decoding fast path
|
||||||
|
# forward_qkv takes will perform QKV projection, rotary position embedding
|
||||||
|
# and save the key/value states to cache, then return query states and the
|
||||||
|
# extended key/value cache
|
||||||
|
# if decoding_fast_path:
|
||||||
|
# hidden_states = hidden_states.view(1, -1)
|
||||||
|
# kv_seq_len = past_key_value[0].shape[-2]
|
||||||
|
# cache_k = past_key_value[0]
|
||||||
|
# cache_v = past_key_value[1]
|
||||||
|
# import linear_q4_0
|
||||||
|
# query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
|
||||||
|
# self.q_proj.weight,
|
||||||
|
# self.k_proj.weight,
|
||||||
|
# self.v_proj.weight,
|
||||||
|
# position_ids,
|
||||||
|
# cache_k, cache_v,
|
||||||
|
# self.q_proj.weight.qtype,
|
||||||
|
# kv_seq_len,
|
||||||
|
# self.head_dim)
|
||||||
|
# kv_seq_len += 1
|
||||||
|
|
||||||
|
# else:
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len,
|
||||||
|
self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len,
|
||||||
|
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len,
|
||||||
|
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
|
||||||
|
|
||||||
|
# TODO: fuse_rope
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||||
|
cos, sin, position_ids, "llama")
|
||||||
|
|
||||||
|
updated_past_key_values = []
|
||||||
|
if past_key_value is not None:
|
||||||
|
batched_attention_output = []
|
||||||
|
# print(f"type of attention_mask is {type(attention_mask)}")
|
||||||
|
for batch in range(bsz):
|
||||||
|
past_k, past_v = past_key_value[batch]
|
||||||
|
current_kv_len = past_k.shape[-2] + 1
|
||||||
|
|
||||||
|
current_key_states = torch.cat([past_k,
|
||||||
|
key_states[batch: batch + 1, :, :, :]], dim=2)
|
||||||
|
current_value_states = torch.cat([past_v,
|
||||||
|
value_states[batch: batch + 1, :, :, :]], dim=2)
|
||||||
|
|
||||||
|
updated_past_key_values.append((current_key_states, current_value_states))
|
||||||
|
|
||||||
|
current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
|
||||||
|
current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
current_query_states = query_states[batch: batch + 1, :, :, :]
|
||||||
|
attn_output, attn_weights = native_sdp(current_query_states,
|
||||||
|
current_key_states,
|
||||||
|
current_value_states,
|
||||||
|
attention_mask[batch],
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
current_kv_len,
|
||||||
|
self.head_dim,
|
||||||
|
self.num_heads)
|
||||||
|
if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
|
||||||
|
invalidInputError(False,
|
||||||
|
f"`attn_output` should be of size "
|
||||||
|
f"{(1, self.num_heads, 1, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}")
|
||||||
|
batched_attention_output.append(attn_output)
|
||||||
|
# For loop ends
|
||||||
|
# TODO: handle attention_weights later
|
||||||
|
attn_output = torch.concat(batched_attention_output, dim=0)
|
||||||
|
batched_attention_output.clear()
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
invalidInputError(False,
|
||||||
|
f"`attn_output` should be of size "
|
||||||
|
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}")
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
return attn_output, None, updated_past_key_values
|
||||||
|
|
||||||
|
# TODO: Assume always use_cache
|
||||||
|
# print(f"prefill with batch size {bsz}")
|
||||||
|
for batch in range(bsz):
|
||||||
|
updated_past_key_values.append((key_states[batch: batch + 1, :, :, :],
|
||||||
|
value_states[batch: batch+1, :, :, :]))
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
|
||||||
|
dtype=attention_dtype)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
||||||
|
dtype=attention_dtype)
|
||||||
|
attn_output, attn_weights = native_sdp(query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attention_mask,
|
||||||
|
bsz,
|
||||||
|
q_len,
|
||||||
|
kv_seq_len,
|
||||||
|
self.head_dim,
|
||||||
|
self.num_heads)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
invalidInputError(False,
|
||||||
|
f"`attn_output` should be of size "
|
||||||
|
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}")
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
return attn_output.to(original_dtype), attn_weights, updated_past_key_values
|
||||||
|
|
||||||
|
|
||||||
def check_flash_attention_available(query):
|
def check_flash_attention_available(query):
|
||||||
# check whether ipex flash attention can be used
|
# check whether ipex flash attention can be used
|
||||||
if query.device.type != "xpu":
|
if query.device.type != "xpu":
|
||||||
|
|
@ -371,3 +533,171 @@ def native_sdp(query, key, value, attention_mask,
|
||||||
dtype=torch.float32).to(value.dtype)
|
dtype=torch.float32).to(value.dtype)
|
||||||
attn_output = torch.matmul(attn_weights, value)
|
attn_output = torch.matmul(attn_weights, value)
|
||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
def llama_model_selective_batching_forward_4_31(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
if output_attentions is not None:
|
||||||
|
output_attentions = output_attentions
|
||||||
|
else:
|
||||||
|
output_attentions = self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
invalidInputError(False,
|
||||||
|
"You cannot specify both decoder_input_ids"
|
||||||
|
" and decoder_inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
invalidInputError(False,
|
||||||
|
"You have to specify either "
|
||||||
|
"decoder_input_ids or decoder_inputs_embeds")
|
||||||
|
|
||||||
|
# seq_length_with_past = seq_length
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
# The original position_ids in the format of [1, 1]
|
||||||
|
# However, this only applies when kv_len is the same for all the sequences
|
||||||
|
# We should set it to format of [batch, position_id]
|
||||||
|
# TODO: validate correctness
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
if position_ids is None:
|
||||||
|
invalidInputError("vLLM: position_ids should never be None")
|
||||||
|
else:
|
||||||
|
# print(f"Original position_ids is {position_ids}")
|
||||||
|
position_ids = position_ids.view(-1, seq_length)
|
||||||
|
# print(f"after position_ids is {position_ids}")
|
||||||
|
# if past_key_values is None:
|
||||||
|
# # For prefill
|
||||||
|
# position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
|
||||||
|
# position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
# else:
|
||||||
|
# past_key_values_length = []
|
||||||
|
# for sequence_kv in past_key_values[0]:
|
||||||
|
# key = sequence_kv[0]
|
||||||
|
# past_key_values_length.append(key.shape[-2])
|
||||||
|
# position_ids = torch.tensor(past_key_values_length, dtype=torch.long, device=device)
|
||||||
|
# position_ids = position_ids.unsqueeze(0).view(-1, 1)
|
||||||
|
|
||||||
|
if past_key_values is not None:
|
||||||
|
# past_key_values in the format of num_layers x num_seqs x 2
|
||||||
|
# TODO: this may be incorrect
|
||||||
|
past_key_values_length = past_key_values[0][0][0].shape[2]
|
||||||
|
# seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
|
||||||
|
# if position_ids is None:
|
||||||
|
# device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
# # [start, end)
|
||||||
|
# position_ids = torch.arange(
|
||||||
|
# past_key_values_length, seq_length +
|
||||||
|
# past_key_values_length, dtype=torch.long, device=device
|
||||||
|
# )
|
||||||
|
# position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
# else:
|
||||||
|
# position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
# embed positions
|
||||||
|
if attention_mask is None:
|
||||||
|
invalidInputError(False, "attention_mask should never be None")
|
||||||
|
# print(f"attention_mask before expanding: {attention_mask}")
|
||||||
|
if past_key_values is None:
|
||||||
|
attention_mask = self._prepare_decoder_attention_mask(
|
||||||
|
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
i = 0
|
||||||
|
for attn_mask in attention_mask:
|
||||||
|
past_key_value_length = past_key_values[0][i][0].shape[2]
|
||||||
|
new_mask = self._prepare_decoder_attention_mask(
|
||||||
|
attn_mask, (1, seq_length), inputs_embeds, past_key_value_length
|
||||||
|
)
|
||||||
|
attention_mask[i] = new_mask
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
invalidInputError(False, "gradient_checkpointing is not supported")
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs, output_attentions, None)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(decoder_layer),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) # noqa
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ from bigdl.llm.vllm.logger import init_logger
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
|
from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
import os
|
||||||
from transformers.generation.logits_process import (
|
from transformers.generation.logits_process import (
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
RepetitionPenaltyLogitsProcessor,
|
RepetitionPenaltyLogitsProcessor,
|
||||||
|
|
@ -50,6 +51,10 @@ def _get_attention_mask_for_prompts(
|
||||||
]
|
]
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
|
||||||
|
enable_vllm_se_batching = vllm_selective_batching is not None
|
||||||
|
enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
||||||
|
|
||||||
|
|
@ -61,12 +66,9 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
||||||
):
|
):
|
||||||
super().__init__(config, device, max_model_len)
|
super().__init__(config, device, max_model_len)
|
||||||
self.config = config
|
self.config = config
|
||||||
# TODO(gc): later change this to a switch?
|
# Always enable bigdl-llm model
|
||||||
if True:
|
|
||||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||||
from bigdl.llm import optimize_model
|
from bigdl.llm import optimize_model
|
||||||
|
|
||||||
# low_bit = 'sym_int4'
|
|
||||||
if device == 'cpu':
|
if device == 'cpu':
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
config._name_or_path,
|
config._name_or_path,
|
||||||
|
|
@ -93,17 +95,19 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
||||||
self.model = model.to('xpu')
|
self.model = model.to('xpu')
|
||||||
self.sampler = BigDLSampler(config.vocab_size, device).to('xpu')
|
self.sampler = BigDLSampler(config.vocab_size, device).to('xpu')
|
||||||
|
|
||||||
if device is None:
|
|
||||||
self.device = torch.device(
|
|
||||||
"cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
else:
|
|
||||||
self.device = torch.device(device)
|
self.device = torch.device(device)
|
||||||
self.dtype = self.model.dtype
|
self.dtype = self.model.dtype
|
||||||
self.last_seq_ids = []
|
self.last_seq_ids = []
|
||||||
self.tmp_kv_cache = None
|
self.last_kv_cache = None
|
||||||
self.pad_token_id = config.pad_token_id
|
self.pad_token_id = config.pad_token_id
|
||||||
self.max_seq_limit = max_model_len
|
self.max_seq_limit = max_model_len
|
||||||
|
|
||||||
|
# GC: Note for selective batching
|
||||||
|
# KV_CACHE in the format of num_layers x 2 x (seq_id -> torch.Tensor)
|
||||||
|
# past_key_values in the format of num_layers x len(seq_id) x (2 x torch.Tensor)
|
||||||
|
# If we set num_layers to 9, have 10 sequences in total.
|
||||||
|
# then, for the kv_cache, we get 9 x 2 x 10 = 180 tensors
|
||||||
|
# for past_key_values, we get 9 x 10 x 2 = 180 tensors
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
seq_group_meta_data_lists: List[SequenceGroupMetadata],
|
seq_group_meta_data_lists: List[SequenceGroupMetadata],
|
||||||
|
|
@ -116,7 +120,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
||||||
decoder_kv_size = 2
|
decoder_kv_size = 2
|
||||||
|
|
||||||
bigdl_input_ids = []
|
bigdl_input_ids = []
|
||||||
bigdl_position_ids = []
|
# bigdl_position_ids = []
|
||||||
bigdl_attention_mask = []
|
bigdl_attention_mask = []
|
||||||
|
|
||||||
cur_seq_ids = []
|
cur_seq_ids = []
|
||||||
|
|
@ -144,8 +148,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
||||||
# 1. Assemble bigdl_input_ids end
|
# 1. Assemble bigdl_input_ids end
|
||||||
|
|
||||||
if is_decoding_stage:
|
if is_decoding_stage:
|
||||||
bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
|
construct_kv_cache_func = self.get_construct_kv_cache_func(enable_vllm_se_batching)
|
||||||
kv_cache, num_layers, decoder_kv_size)
|
bigdl_kv_cache = construct_kv_cache_func(cur_seq_ids,
|
||||||
|
seq_group_meta_data_lists,
|
||||||
|
kv_cache,
|
||||||
|
num_layers,
|
||||||
|
2)
|
||||||
else:
|
else:
|
||||||
bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
|
bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
|
||||||
bigdl_input_ids = [
|
bigdl_input_ids = [
|
||||||
|
|
@ -153,7 +161,24 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
||||||
for input_ids in bigdl_input_ids
|
for input_ids in bigdl_input_ids
|
||||||
]
|
]
|
||||||
|
|
||||||
|
decoding_attention_mask_list = []
|
||||||
|
decoding_position_ids = []
|
||||||
|
# num_layers x len(seq_id) x (2 x torch.Tensor)
|
||||||
if is_decoding_stage:
|
if is_decoding_stage:
|
||||||
|
if enable_vllm_se_batching:
|
||||||
|
batch = 0
|
||||||
|
for seq_group_meta_data in seq_group_meta_data_lists:
|
||||||
|
# Get current seq_len in kv_cache
|
||||||
|
current_seq_len = bigdl_kv_cache[0][batch][0].size(2)
|
||||||
|
batch += 1
|
||||||
|
seq_ids = list(seq_group_meta_data.seq_data.keys())
|
||||||
|
seq_data = seq_group_meta_data.seq_data[seq_ids[0]]
|
||||||
|
cur_pos = seq_data.get_len()
|
||||||
|
decoding_position_ids.append(cur_pos - 1)
|
||||||
|
# Total length: current_seq_len + 1
|
||||||
|
cur_attention_mask = [0] * (current_seq_len - cur_pos + 1) + [1] * (cur_pos)
|
||||||
|
decoding_attention_mask_list.append(cur_attention_mask)
|
||||||
|
else:
|
||||||
cur_seq_len = bigdl_kv_cache[0][0].size(2)
|
cur_seq_len = bigdl_kv_cache[0][0].size(2)
|
||||||
for seq_group_meta_data in seq_group_meta_data_lists:
|
for seq_group_meta_data in seq_group_meta_data_lists:
|
||||||
seq_ids = list(seq_group_meta_data.seq_data.keys())
|
seq_ids = list(seq_group_meta_data.seq_data.keys())
|
||||||
|
|
@ -161,33 +186,47 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
||||||
seq_data = seq_group_meta_data.seq_data[seq_id]
|
seq_data = seq_group_meta_data.seq_data[seq_id]
|
||||||
cur_pos = seq_data.get_len()
|
cur_pos = seq_data.get_len()
|
||||||
# bigdl_position_ids.append([cur_pos - 1])
|
# bigdl_position_ids.append([cur_pos - 1])
|
||||||
|
# decoding_position_ids.append(cur_pos - 1)
|
||||||
cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
|
cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
|
||||||
bigdl_attention_mask.append(cur_attention_mask)
|
decoding_attention_mask_list.append(cur_attention_mask)
|
||||||
|
|
||||||
bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
|
bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
|
||||||
|
|
||||||
if is_decoding_stage:
|
if is_decoding_stage:
|
||||||
# bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
|
if enable_vllm_se_batching:
|
||||||
bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
|
attention_mask = [torch.tensor(x, device=self.device).unsqueeze(0)
|
||||||
|
for x in decoding_attention_mask_list]
|
||||||
|
position_ids = torch.tensor(decoding_position_ids).long().unsqueeze(-1)
|
||||||
|
else:
|
||||||
|
attention_mask = torch.tensor(decoding_attention_mask_list, device=self.device)
|
||||||
|
position_ids = None
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"input_ids": bigdl_input_ids,
|
"input_ids": bigdl_input_ids,
|
||||||
# "position_ids": bigdl_position_ids,
|
"position_ids": position_ids,
|
||||||
"attention_mask": bigdl_attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"past_key_values": bigdl_kv_cache,
|
"past_key_values": bigdl_kv_cache,
|
||||||
"use_cache": True,
|
"use_cache": True,
|
||||||
# "return_dict": True,
|
# "return_dict": True,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
# Prefill stage
|
||||||
|
attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
|
||||||
|
if enable_vllm_se_batching:
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
else:
|
||||||
|
position_ids = None
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"input_ids": bigdl_input_ids,
|
"input_ids": bigdl_input_ids,
|
||||||
"attention_mask": torch.tensor(bigdl_attention_mask, device=self.device),
|
"attention_mask": attention_mask,
|
||||||
# "position_ids": bigdl_position_ids,
|
"position_ids": position_ids,
|
||||||
"past_key_values": None,
|
"past_key_values": None,
|
||||||
"use_cache": True,
|
"use_cache": True,
|
||||||
# "return_dict": True,
|
# "return_dict": True,
|
||||||
}
|
}
|
||||||
|
# Prefill may need additional space, which forces us to delete the last_kv_cache
|
||||||
if self.last_kv_cache:
|
if self.last_kv_cache:
|
||||||
del self.last_kv_cache
|
self.last_kv_cache = None
|
||||||
# pdb.set_trace()
|
# pdb.set_trace()
|
||||||
|
|
||||||
if self.device.type == 'xpu':
|
if self.device.type == 'xpu':
|
||||||
|
|
@ -207,8 +246,12 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
||||||
# tmp = torch.xpu.memory_stats()
|
# tmp = torch.xpu.memory_stats()
|
||||||
# logger.info(f"before: {tmp['allocated_bytes.all.current']}")
|
# logger.info(f"before: {tmp['allocated_bytes.all.current']}")
|
||||||
|
|
||||||
self.update_kv_cache(cur_seq_ids,
|
if enable_vllm_se_batching:
|
||||||
kv_cache, num_layers, decoder_kv_size)
|
self.update_kv_cache_selective_batching(
|
||||||
|
cur_seq_ids, kv_cache, num_layers, decoder_kv_size)
|
||||||
|
self.last_kv_cache = None
|
||||||
|
else:
|
||||||
|
self.update_kv_cache(cur_seq_ids, kv_cache, num_layers, decoder_kv_size)
|
||||||
|
|
||||||
# tmp = torch.xpu.memory_stats()
|
# tmp = torch.xpu.memory_stats()
|
||||||
# logger.info(f"after: {tmp['allocated_bytes.all.current']}")
|
# logger.info(f"after: {tmp['allocated_bytes.all.current']}")
|
||||||
|
|
|
||||||
|
|
@ -137,6 +137,34 @@ class BigDLModelForCausalLM(nn.Module):
|
||||||
|
|
||||||
return bigdl_kv_cache
|
return bigdl_kv_cache
|
||||||
|
|
||||||
|
def get_construct_kv_cache_func(self, enable_selective_batching):
|
||||||
|
if enable_selective_batching:
|
||||||
|
return self.prepare_kv_cache_selective_batching
|
||||||
|
else:
|
||||||
|
return self.prepare_kv_cache
|
||||||
|
|
||||||
|
# This is an implementation for models that KV Cache shape in (batch_size, num_heads,
|
||||||
|
# sequence_length, embed_size_per_head).
|
||||||
|
def prepare_kv_cache_selective_batching(
|
||||||
|
self,
|
||||||
|
cur_seq_ids: List[int],
|
||||||
|
seq_group_meta_data_lists: List[SequenceGroupMetadata],
|
||||||
|
kv_cache: Dict,
|
||||||
|
num_layers: int,
|
||||||
|
kv_cache_size_1: int,
|
||||||
|
):
|
||||||
|
# Return bigdl_kv_cache in the format of Tuple(List[Tuple(torch.Tensor)])
|
||||||
|
bigdl_kv_cache = []
|
||||||
|
for i in range(num_layers):
|
||||||
|
# Construct a list of tuple(tensor)
|
||||||
|
temp_cache = []
|
||||||
|
for seq_id in cur_seq_ids:
|
||||||
|
key = kv_cache[i][0][seq_id]
|
||||||
|
value = kv_cache[i][1][seq_id]
|
||||||
|
temp_cache.append((key, value))
|
||||||
|
bigdl_kv_cache.append(temp_cache)
|
||||||
|
return bigdl_kv_cache
|
||||||
|
|
||||||
# This is an implementation for models that KV Cache shape in (batch_size, num_heads,
|
# This is an implementation for models that KV Cache shape in (batch_size, num_heads,
|
||||||
# sequence_length, embed_size_per_head).
|
# sequence_length, embed_size_per_head).
|
||||||
def update_kv_cache(
|
def update_kv_cache(
|
||||||
|
|
@ -153,6 +181,18 @@ class BigDLModelForCausalLM(nn.Module):
|
||||||
kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim]
|
kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim]
|
||||||
batch_dim = batch_dim + 1
|
batch_dim = batch_dim + 1
|
||||||
|
|
||||||
|
def update_kv_cache_selective_batching(
|
||||||
|
self,
|
||||||
|
cur_seq_ids: List[int],
|
||||||
|
kv_cache,
|
||||||
|
layer: int,
|
||||||
|
kv_cache_size_1: int,
|
||||||
|
) -> None:
|
||||||
|
for i in range(layer):
|
||||||
|
for j in range(len(cur_seq_ids)):
|
||||||
|
kv_cache[i][0][cur_seq_ids[j]] = self.last_kv_cache[i][j][0]
|
||||||
|
kv_cache[i][1][cur_seq_ids[j]] = self.last_kv_cache[i][j][1]
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
seq_group_meta_data_lists: List[SequenceGroupMetadata],
|
seq_group_meta_data_lists: List[SequenceGroupMetadata],
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue