[VLLM] Change padding patterns for vLLM & clean code (#9609)
* optimize
* fix minor error
* optimizations
* fix style

parent 89069d6173
commit 6978b2c316

2 changed files with 78 additions and 47 deletions
@@ -26,6 +26,7 @@ from bigdl.llm.vllm.model_executor.models.bigdl_model import BigDLModelForCausalLM
 from bigdl.llm.vllm.logger import init_logger
 import math
 import time
+from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -39,7 +40,16 @@ logger = init_logger(__name__)


 def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]:
-    return x + [padding_id] * (max_len - len(x))
+    return [padding_id] * (max_len - len(x)) + x
+
+
+def _get_attention_mask_for_prompts(
+    input_ids: List[List[int]], max_prompt_len: int
+) -> List[List[int]]:
+    attention_mask = [
+        [0] * (max_prompt_len - len(prompt)) + [1] * len(prompt) for prompt in input_ids
+    ]
+    return attention_mask


 class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
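
The hunk above flips _pad_to_max from right padding to left padding and introduces _get_attention_mask_for_prompts so that pad positions are masked out. A small self-contained illustration of how the two helpers behave on an uneven batch (the token ids below are made up):

from typing import List

def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]:
    # Left padding: pad tokens go in front, real tokens stay flush with the end.
    return [padding_id] * (max_len - len(x)) + x

def _get_attention_mask_for_prompts(
    input_ids: List[List[int]], max_prompt_len: int
) -> List[List[int]]:
    # 0 over the pad positions, 1 over the real prompt tokens.
    return [[0] * (max_prompt_len - len(p)) + [1] * len(p) for p in input_ids]

prompts = [[11, 12, 13], [21, 22]]                       # made-up token ids
max_len = max(len(p) for p in prompts)
print([_pad_to_max(p, max_len, 0) for p in prompts])     # [[11, 12, 13], [0, 21, 22]]
print(_get_attention_mask_for_prompts(prompts, max_len)) # [[1, 1, 1], [0, 1, 1]]

With left padding, the last real token of every prompt sits at index -1 of its row, which is the position the decoding step samples from.
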
@@ -98,49 +108,53 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
     def forward(
         self,
         seq_group_meta_data_lists: List[SequenceGroupMetadata],
-        kv_cache: Optional = None,
-        input_metadata: Optional = None,
+        # kv_cache in the format [[dict() for _ in range(2)] for _ in range(32)]
+        kv_cache: Optional[List[List[Dict]]] = None,
+        input_metadata: Optional[InputMetadata] = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        kv_cache_size_0 = self.model.config.num_hidden_layers
-        kv_cache_size_1 = 2
         seq_len = len(seq_group_meta_data_lists)
+        num_layers = self.model.config.num_hidden_layers
+        # One for key, one for value
+        decoder_kv_size = 2

         bigdl_input_ids = []
         bigdl_position_ids = []
         bigdl_attention_mask = []

         cur_seq_ids = []
         bigdl_sampling_params = {}
-        max_context_len = 0
-        all_decoding = True
+        max_prompt_len = 0

+        # 0. Verify is_prompt or is_decoding
+        is_decoding_stage = not seq_group_meta_data_lists[0].is_prompt
+
+        # 1. Assemble bigdl_input_ids
         for seq_group_meta_data in seq_group_meta_data_lists:
-            req_id = seq_group_meta_data.request_id
-            all_decoding = all_decoding and (not seq_group_meta_data.is_prompt)
+            # req_id = seq_group_meta_data.request_id
+            # is_decoding_stage = is_decoding_stage and (not seq_group_meta_data.is_prompt)
             seq_ids = list(seq_group_meta_data.seq_data.keys())
             seq_id = seq_ids[0]
             cur_seq_ids.append(seq_id)
             seq_data = seq_group_meta_data.seq_data[seq_id]

             cur_seq_input_ids = seq_data.get_token_ids()
-            context_len = seq_data.get_len()
+            # context_len = seq_data.get_len()
             if seq_group_meta_data.is_prompt:
                 bigdl_input_ids.append(cur_seq_input_ids)
-                max_context_len = max(max_context_len, context_len)
+                max_prompt_len = max(max_prompt_len, seq_data.get_len())
             else:
                 bigdl_input_ids.append([cur_seq_input_ids[-1]])
+            # 1. Assemble bigdl_input_ids end

             bigdl_sampling_params[seq_id] = seq_group_meta_data.sampling_params

-        if all_decoding:
+        if is_decoding_stage:
             bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
-                                                   kv_cache, kv_cache_size_0, kv_cache_size_1)
+                                                   kv_cache, num_layers, decoder_kv_size)
         else:
+            bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
             bigdl_input_ids = [
-                _pad_to_max(input_ids, max_context_len, self.pad_token_id)
+                _pad_to_max(input_ids, max_prompt_len, self.pad_token_id)
                 for input_ids in bigdl_input_ids
             ]

-        if all_decoding:
+        if is_decoding_stage:
             cur_seq_len = bigdl_kv_cache[0][0].size(2)
             for seq_group_meta_data in seq_group_meta_data_lists:
                 seq_ids = list(seq_group_meta_data.seq_data.keys())
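
The retyped kv_cache argument follows the layout spelled out in the new comment: one [key_dict, value_dict] pair per decoder layer, each dict keyed by sequence id. The rewrite also assumes a scheduled batch is homogeneous (all prompts or all decoding steps), which is why checking the first SequenceGroupMetadata replaces the old all_decoding accumulation. A rough sketch of that container, with illustrative layer count and tensor shapes:

import torch

num_layers, kv_heads, head_dim = 32, 32, 128   # illustrative sizes, not taken from the config

# kv_cache[layer][0] maps seq_id -> key tensor, kv_cache[layer][1] maps seq_id -> value tensor
kv_cache = [[dict() for _ in range(2)] for _ in range(num_layers)]

# Store the layer-0 KV tensors of sequence 7 after an 11-token prompt (shapes assumed).
seq_id, prompt_len = 7, 11
kv_cache[0][0][seq_id] = torch.zeros(kv_heads, prompt_len, head_dim)  # keys
kv_cache[0][1][seq_id] = torch.zeros(kv_heads, prompt_len, head_dim)  # values
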
@@ -152,7 +166,8 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
                 bigdl_attention_mask.append(cur_attention_mask)

         bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
-        if all_decoding:
+
+        if is_decoding_stage:
             bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
             bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
             kwargs = {
@@ -166,6 +181,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
         else:
             kwargs = {
                 "input_ids": bigdl_input_ids,
+                "attention_mask": torch.tensor(bigdl_attention_mask, device=self.device),
                 # "position_ids": bigdl_position_ids,
                 "past_key_values": None,
                 "use_cache": True,
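
For the prompt stage the padded batch is handed to the underlying Hugging Face-style causal LM with an explicit attention_mask and no past_key_values. The sketch below reproduces that call pattern with a small stand-in model (gpt2) and made-up token ids; it is an illustration of the API shape, not the repository's exact invocation:

import torch
from transformers import AutoModelForCausalLM

# Any Hugging Face causal LM works for the sketch; gpt2 is just a small stand-in.
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = torch.tensor([[0, 21, 22], [11, 12, 13]])   # left-padded batch (made-up ids)
attention_mask = torch.tensor([[0, 1, 1], [1, 1, 1]])   # 0 marks the pad positions

out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    past_key_values=None,   # prompt stage: nothing cached yet
    use_cache=True,         # return past_key_values for the following decoding steps
)
next_token = out.logits[:, -1, :].argmax(dim=-1)   # with left padding, -1 is the last real token
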
@@ -190,7 +206,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
             # logger.info(f"before: {tmp['allocated_bytes.all.current']}")

         self.update_kv_cache(cur_seq_ids,
-                             kv_cache, kv_cache_size_0, kv_cache_size_1)
+                             kv_cache, num_layers, decoder_kv_size)

         # tmp = torch.xpu.memory_stats()
         # logger.info(f"after: {tmp['allocated_bytes.all.current']}")

@@ -90,45 +90,60 @@ class BigDLModelForCausalLM(nn.Module):
         cur_seq_ids: List[int],
         seq_group_meta_data_lists: List[SequenceGroupMetadata],
         kv_cache: Dict,
-        kv_cache_size_0: int,
+        num_layers: int,
         kv_cache_size_1: int,
     ):
         max_seq_limit = self.max_seq_limit
         if (self.last_kv_cache is not None) and cur_seq_ids == self.last_seq_ids:
             if self.last_kv_cache[0][0].size(2) < max_seq_limit * 1.5:
                 bigdl_kv_cache = self.last_kv_cache
+                # Immediately set it to None to decrease ref-count
+                self.last_kv_cache = None
             else:
                 bigdl_kv_cache = [[tmp.narrow(2, self.last_kv_cache[0][0].size(2)
                                               - max_seq_limit, max_seq_limit)
                                    for tmp in tmp_list] for tmp_list in self.last_kv_cache]
-                del self.last_kv_cache
+                self.last_kv_cache = None
         else:
-            del self.last_kv_cache
             bigdl_kv_cache = []
-            for i in range(kv_cache_size_0):
+            max_kv_len = max(kv_cache[0][0][processed_seq_id].size(dim=1)
+                             for processed_seq_id in cur_seq_ids)
+            max_kv_len = min(max_kv_len, max_seq_limit)
+            for layer in range(num_layers):
                 cur_list = []
-                for j in range(kv_cache_size_1):
-                    views = []
-                    max_len = 0
-                    for seq_group_meta_data in seq_group_meta_data_lists:
-                        seq_ids = list(seq_group_meta_data.seq_data.keys())
-                        seq_id = seq_ids[0]
-                        seq_data = seq_group_meta_data.seq_data[seq_id]
-                        view_size = [1] + list(kv_cache[i][j][seq_id].shape)
-                        views.append(kv_cache[i][j][seq_id].view(view_size))
-                        max_len = max(max_len, view_size[2])
+                for kv in range(kv_cache_size_1):
+                    kv_list = []
+                    # for seq_group_meta_data in seq_group_meta_data_lists:
+                    #     seq_ids = list(seq_group_meta_data.seq_data.keys())
+                    #     seq_id = seq_ids[0]
+                    #     # seq_data = seq_group_meta_data.seq_data[seq_id]
+                    #     view_size = [1] + list(kv_cache[layer][kv][seq_id].shape)
+                    #     kv_list.append(kv_cache[layer][kv][seq_id].view(view_size))
+                    for seq_id in cur_seq_ids:
+                        processed_kv_cache = kv_cache[layer][kv][seq_id]
+                        # Clean
+                        kv_cache[layer][kv][processed_kv_cache] = None
+                        if processed_kv_cache.size(dim=1) != max_kv_len:
+                            processed_kv_cache = _pad_kv_cache_view(processed_kv_cache, max_kv_len,
+                                                                    self.device, 1)
+                        # Do padding
+                        kv_list.append(processed_kv_cache)
+                    current_layer_kv_cache = torch.stack(kv_list, dim=0)
+                    kv_list.clear()

-                    views = [_pad_kv_cache_view(v, max_len, self.device) for v in views]
-                    cur_view = torch.cat(views, dim=0)
+                    # kv_list = [_pad_kv_cache_view(v, max_kv_len, self.device) for v in kv_list]
+                    # cur_view = torch.cat(kv_list, dim=0)

-                    if cur_view.size(2) > max_seq_limit:
-                        cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
-                    cur_list.append(cur_view)
+                    # if cur_view.size(2) > max_seq_limit:
+                    #     cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
+                    cur_list.append(current_layer_kv_cache)

-                    for seq_group_meta_data in seq_group_meta_data_lists:
-                        seq_ids = list(seq_group_meta_data.seq_data.keys())
-                        seq_id = seq_ids[0]
-                        del kv_cache[i][j][seq_id]
+                    # for seq_group_meta_data in seq_group_meta_data_lists:
+                    #     seq_ids = list(seq_group_meta_data.seq_data.keys())
+                    #     seq_id = seq_ids[0]
+                    #     del kv_cache[layer][kv][seq_id]
                 bigdl_kv_cache.append(cur_list)

         return bigdl_kv_cache
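
prepare_kv_cache now pads every sequence's cached keys/values to a common max_kv_len (capped at max_seq_limit) and stacks them with torch.stack, instead of reshaping each entry to a 1-batch view and concatenating. _pad_kv_cache_view is internal to the repository, so the sketch below imitates the idea with plain torch.nn.functional.pad; the tensor shapes and the choice to pad on the left of the sequence dimension are assumptions made for illustration:

import torch
import torch.nn.functional as F

def pad_kv_to_len(kv: torch.Tensor, target_len: int) -> torch.Tensor:
    # kv has shape (num_heads, seq_len, head_dim); pad dim 1 up to target_len.
    # Pad on the left here to match the left-padding of the input ids (an assumption).
    pad = target_len - kv.size(1)
    return F.pad(kv, (0, 0, pad, 0)) if pad > 0 else kv

# Two sequences with caches of different lengths (made-up shapes).
per_seq_kv = {7: torch.randn(32, 5, 128), 9: torch.randn(32, 8, 128)}
max_kv_len = max(t.size(1) for t in per_seq_kv.values())

batched = torch.stack([pad_kv_to_len(t, max_kv_len) for t in per_seq_kv.values()], dim=0)
print(batched.shape)   # torch.Size([2, 32, 8, 128]) -> (batch, heads, seq_len, head_dim)
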
@@ -139,15 +154,15 @@
         self,
         cur_seq_ids: List[int],
         kv_cache,
-        kv_cache_size_0: int,
+        layer: int,
         kv_cache_size_1: int,
     ) -> None:
-        for i in range(kv_cache_size_0):
+        for i in range(layer):
             for j in range(kv_cache_size_1):
-                index = 0
+                batch_dim = 0
                 for seq_id in cur_seq_ids:
-                    kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][index]
-                    index = index + 1
+                    kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim]
+                    batch_dim = batch_dim + 1

     def forward(
         self,
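
update_kv_cache goes the other way: it scatters the batched last_kv_cache back into the per-sequence dicts, one batch row per sequence id, with the row index now named batch_dim instead of index. A condensed, self-contained sketch of that indexing, with tiny assumed shapes rather than the repository's real ones:

import torch

num_layers, kv_size = 2, 2                      # tiny illustrative sizes
seq_ids = [7, 9]

# Batched cache: last_kv_cache[layer][kv] has shape (batch, heads, seq_len, head_dim).
last_kv_cache = [[torch.randn(len(seq_ids), 4, 6, 8) for _ in range(kv_size)]
                 for _ in range(num_layers)]
kv_cache = [[dict() for _ in range(kv_size)] for _ in range(num_layers)]

for i in range(num_layers):
    for j in range(kv_size):
        for batch_dim, seq_id in enumerate(seq_ids):
            # Row `batch_dim` of the batched tensor belongs to sequence `seq_id`.
            kv_cache[i][j][seq_id] = last_kv_cache[i][j][batch_dim]

print(kv_cache[0][0][7].shape)   # torch.Size([4, 6, 8])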