[VLLM] Change padding patterns for vLLM & clean code (#9609)

* optimize

* fix minor error

* optimizations

* fix style
Guancheng Fu 2023-12-06 15:27:26 +08:00 committed by GitHub
parent 89069d6173
commit 6978b2c316
2 changed files with 78 additions and 47 deletions
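The heart of the change is switching prompt batches from right- to left-padding and adding a helper that builds the matching attention mask. A quick standalone sketch of how the two new helpers behave (the function bodies are copied from the diff below; the demo values are made up):

from typing import List

def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]:
    # New behaviour: pad on the left so the real tokens sit at the end of the row.
    return [padding_id] * (max_len - len(x)) + x

def _get_attention_mask_for_prompts(
    input_ids: List[List[int]], max_prompt_len: int
) -> List[List[int]]:
    # 0 over the left padding, 1 over the real prompt tokens.
    return [
        [0] * (max_prompt_len - len(prompt)) + [1] * len(prompt) for prompt in input_ids
    ]

prompts = [[11, 12, 13], [21, 22]]
max_prompt_len = max(len(p) for p in prompts)
print([_pad_to_max(p, max_prompt_len, 0) for p in prompts])
# [[11, 12, 13], [0, 21, 22]]
print(_get_attention_mask_for_prompts(prompts, max_prompt_len))
# [[1, 1, 1], [0, 1, 1]]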


@@ -26,6 +26,7 @@ from bigdl.llm.vllm.model_executor.models.bigdl_model import BigDLModelForCausal
from bigdl.llm.vllm.logger import init_logger
import math
import time
from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
@@ -39,7 +40,16 @@ logger = init_logger(__name__)
def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]:
return x + [padding_id] * (max_len - len(x))
return [padding_id] * (max_len - len(x)) + x
def _get_attention_mask_for_prompts(
input_ids: List[List[int]], max_prompt_len: int
) -> List[List[int]]:
attention_mask = [
[0] * (max_prompt_len - len(prompt)) + [1] * len(prompt) for prompt in input_ids
]
return attention_mask
class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
@@ -98,49 +108,53 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
def forward(
self,
seq_group_meta_data_lists: List[SequenceGroupMetadata],
kv_cache: Optional = None,
input_metadata: Optional = None,
# kv_cache in the format [[dict() for _ in range(2)] for _ in range(32)]
kv_cache: Optional[List[List[Dict]]] = None,
input_metadata: Optional[InputMetadata] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
kv_cache_size_0 = self.model.config.num_hidden_layers
kv_cache_size_1 = 2
seq_len = len(seq_group_meta_data_lists)
num_layers = self.model.config.num_hidden_layers
# One for key, one for value
decoder_kv_size = 2
bigdl_input_ids = []
bigdl_position_ids = []
bigdl_attention_mask = []
cur_seq_ids = []
bigdl_sampling_params = {}
max_context_len = 0
all_decoding = True
max_prompt_len = 0
# 0. Verify is_prompt or is_decoding
is_decoding_stage = not seq_group_meta_data_lists[0].is_prompt
# 1. Assemble bigdl_input_ids
for seq_group_meta_data in seq_group_meta_data_lists:
req_id = seq_group_meta_data.request_id
all_decoding = all_decoding and (not seq_group_meta_data.is_prompt)
# req_id = seq_group_meta_data.request_id
# is_decoding_stage = is_decoding_stage and (not seq_group_meta_data.is_prompt)
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
cur_seq_ids.append(seq_id)
seq_data = seq_group_meta_data.seq_data[seq_id]
cur_seq_input_ids = seq_data.get_token_ids()
context_len = seq_data.get_len()
# context_len = seq_data.get_len()
if seq_group_meta_data.is_prompt:
bigdl_input_ids.append(cur_seq_input_ids)
max_context_len = max(max_context_len, context_len)
max_prompt_len = max(max_prompt_len, seq_data.get_len())
else:
bigdl_input_ids.append([cur_seq_input_ids[-1]])
# 1. Assemble bigdl_input_ids end
bigdl_sampling_params[seq_id] = seq_group_meta_data.sampling_params
if all_decoding:
if is_decoding_stage:
bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
kv_cache, kv_cache_size_0, kv_cache_size_1)
kv_cache, num_layers, decoder_kv_size)
else:
bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
bigdl_input_ids = [
_pad_to_max(input_ids, max_context_len, self.pad_token_id)
_pad_to_max(input_ids, max_prompt_len, self.pad_token_id)
for input_ids in bigdl_input_ids
]
if all_decoding:
if is_decoding_stage:
cur_seq_len = bigdl_kv_cache[0][0].size(2)
for seq_group_meta_data in seq_group_meta_data_lists:
seq_ids = list(seq_group_meta_data.seq_data.keys())
@@ -152,7 +166,8 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
bigdl_attention_mask.append(cur_attention_mask)
bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
if all_decoding:
if is_decoding_stage:
bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
kwargs = {
@@ -166,6 +181,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
else:
kwargs = {
"input_ids": bigdl_input_ids,
"attention_mask": torch.tensor(bigdl_attention_mask, device=self.device),
# "position_ids": bigdl_position_ids,
"past_key_values": None,
"use_cache": True,
@@ -190,7 +206,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
# logger.info(f"before: {tmp['allocated_bytes.all.current']}")
self.update_kv_cache(cur_seq_ids,
kv_cache, kv_cache_size_0, kv_cache_size_1)
kv_cache, num_layers, decoder_kv_size)
# tmp = torch.xpu.memory_stats()
# logger.info(f"after: {tmp['allocated_bytes.all.current']}")


@@ -90,45 +90,60 @@ class BigDLModelForCausalLM(nn.Module):
cur_seq_ids: List[int],
seq_group_meta_data_lists: List[SequenceGroupMetadata],
kv_cache: Dict,
kv_cache_size_0: int,
num_layers: int,
kv_cache_size_1: int,
):
max_seq_limit = self.max_seq_limit
if (self.last_kv_cache is not None) and cur_seq_ids == self.last_seq_ids:
if self.last_kv_cache[0][0].size(2) < max_seq_limit * 1.5:
bigdl_kv_cache = self.last_kv_cache
# Immediately set it to None to decrease ref-count
self.last_kv_cache = None
else:
bigdl_kv_cache = [[tmp.narrow(2, self.last_kv_cache[0][0].size(2)
- max_seq_limit, max_seq_limit)
for tmp in tmp_list] for tmp_list in self.last_kv_cache]
del self.last_kv_cache
self.last_kv_cache = None
else:
del self.last_kv_cache
bigdl_kv_cache = []
for i in range(kv_cache_size_0):
max_kv_len = max(kv_cache[0][0][processed_seq_id].size(dim=1)
for processed_seq_id in cur_seq_ids)
max_kv_len = min(max_kv_len, max_seq_limit)
for layer in range(num_layers):
cur_list = []
for j in range(kv_cache_size_1):
views = []
max_len = 0
for seq_group_meta_data in seq_group_meta_data_lists:
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
seq_data = seq_group_meta_data.seq_data[seq_id]
view_size = [1] + list(kv_cache[i][j][seq_id].shape)
views.append(kv_cache[i][j][seq_id].view(view_size))
max_len = max(max_len, view_size[2])
for kv in range(kv_cache_size_1):
kv_list = []
# for seq_group_meta_data in seq_group_meta_data_lists:
# seq_ids = list(seq_group_meta_data.seq_data.keys())
# seq_id = seq_ids[0]
# # seq_data = seq_group_meta_data.seq_data[seq_id]
# view_size = [1] + list(kv_cache[layer][kv][seq_id].shape)
# kv_list.append(kv_cache[layer][kv][seq_id].view(view_size))
for seq_id in cur_seq_ids:
processed_kv_cache = kv_cache[layer][kv][seq_id]
# Clean
kv_cache[layer][kv][seq_id] = None
if processed_kv_cache.size(dim=1) != max_kv_len:
processed_kv_cache = _pad_kv_cache_view(processed_kv_cache, max_kv_len,
self.device, 1)
# Do padding
kv_list.append(processed_kv_cache)
current_layer_kv_cache = torch.stack(kv_list, dim=0)
kv_list.clear()
views = [_pad_kv_cache_view(v, max_len, self.device) for v in views]
cur_view = torch.cat(views, dim=0)
# kv_list = [_pad_kv_cache_view(v, max_kv_len, self.device) for v in kv_list]
# cur_view = torch.cat(kv_list, dim=0)
if cur_view.size(2) > max_seq_limit:
cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
cur_list.append(cur_view)
# if cur_view.size(2) > max_seq_limit:
# cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
cur_list.append(current_layer_kv_cache)
for seq_group_meta_data in seq_group_meta_data_lists:
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
del kv_cache[i][j][seq_id]
# for seq_group_meta_data in seq_group_meta_data_lists:
# seq_ids = list(seq_group_meta_data.seq_data.keys())
# seq_id = seq_ids[0]
# del kv_cache[layer][kv][seq_id]
bigdl_kv_cache.append(cur_list)
return bigdl_kv_cache
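prepare_kv_cache now pads every sequence's cached keys/values up to the longest length in the batch (capped at max_seq_limit) and stacks them along a new batch dimension, replacing the old view/cat path. A minimal standalone sketch of that pad-and-stack step, using torch.nn.functional.pad in place of the repo's _pad_kv_cache_view helper and assuming left padding on the sequence dimension to match the rest of the change:

import torch
import torch.nn.functional as F

# Per-sequence caches: {seq_id: tensor of shape [num_heads, kv_len, head_dim]}
per_seq_kv = {
    0: torch.randn(4, 5, 8),
    1: torch.randn(4, 3, 8),
}
cur_seq_ids = [0, 1]
max_seq_limit = 16

max_kv_len = max(per_seq_kv[s].size(1) for s in cur_seq_ids)
max_kv_len = min(max_kv_len, max_seq_limit)

kv_list = []
for seq_id in cur_seq_ids:
    kv = per_seq_kv[seq_id]
    if kv.size(1) != max_kv_len:
        # Pad the sequence dimension (dim 1) on the left up to max_kv_len.
        kv = F.pad(kv, (0, 0, max_kv_len - kv.size(1), 0))
    kv_list.append(kv)

# One batched tensor per (layer, key/value) pair: [batch, num_heads, kv_len, head_dim]
batched = torch.stack(kv_list, dim=0)
print(batched.shape)  # torch.Size([2, 4, 5, 8])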
@@ -139,15 +154,15 @@ class BigDLModelForCausalLM(nn.Module):
self,
cur_seq_ids: List[int],
kv_cache,
kv_cache_size_0: int,
layer: int,
kv_cache_size_1: int,
) -> None:
for i in range(kv_cache_size_0):
for i in range(layer):
for j in range(kv_cache_size_1):
index = 0
batch_dim = 0
for seq_id in cur_seq_ids:
kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][index]
index = index + 1
kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim]
batch_dim = batch_dim + 1
def forward(
self,
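update_kv_cache performs the inverse scatter after the model call: row batch_dim of each batched past-key/value tensor is written back to the per-sequence dict for that seq_id. A small illustrative sketch of that loop (shapes and ids are made up; the nested list-of-dicts layout mirrors the kv_cache comment in the diff above):

import torch

num_layers, kv_per_layer = 2, 2          # keys and values per layer
cur_seq_ids = [7, 9]

# Batched cache as produced by the model: [batch, num_heads, kv_len, head_dim]
last_kv_cache = [[torch.randn(len(cur_seq_ids), 4, 6, 8) for _ in range(kv_per_layer)]
                 for _ in range(num_layers)]

# Per-sequence store: kv_cache[layer][k_or_v][seq_id] -> tensor
kv_cache = [[dict() for _ in range(kv_per_layer)] for _ in range(num_layers)]

for i in range(num_layers):
    for j in range(kv_per_layer):
        batch_dim = 0
        for seq_id in cur_seq_ids:
            # Row batch_dim of the batched tensor belongs to seq_id.
            kv_cache[i][j][seq_id] = last_kv_cache[i][j][batch_dim]
            batch_dim += 1

print(kv_cache[0][0][7].shape)  # torch.Size([4, 6, 8])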