[VLLM] Change padding patterns for vLLM & clean code (#9609)
* optimize
* fix minor error
* optimizations
* fix style
Parent: 89069d6173
Commit: 6978b2c316
2 changed files with 78 additions and 47 deletions
@@ -26,6 +26,7 @@ from bigdl.llm.vllm.model_executor.models.bigdl_model import BigDLModelForCausal
 from bigdl.llm.vllm.logger import init_logger
 import math
 import time
+from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -39,7 +40,16 @@ logger = init_logger(__name__)


 def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]:
-    return x + [padding_id] * (max_len - len(x))
+    return [padding_id] * (max_len - len(x)) + x
+
+
+def _get_attention_mask_for_prompts(
+    input_ids: List[List[int]], max_prompt_len: int
+) -> List[List[int]]:
+    attention_mask = [
+        [0] * (max_prompt_len - len(prompt)) + [1] * len(prompt) for prompt in input_ids
+    ]
+    return attention_mask


 class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
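For illustration, here is how the two helpers above behave on a toy prompt batch: prompts are left-padded so the real tokens stay right-aligned, and the mask marks the padded prefix with zeros. The helpers are reproduced from the hunk; the token ids are made up.

from typing import List

def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]:
    # Left-pad: pad tokens go in front of the real tokens.
    return [padding_id] * (max_len - len(x)) + x

def _get_attention_mask_for_prompts(
    input_ids: List[List[int]], max_prompt_len: int
) -> List[List[int]]:
    # 0 over the padded prefix, 1 over the real prompt tokens.
    attention_mask = [
        [0] * (max_prompt_len - len(prompt)) + [1] * len(prompt) for prompt in input_ids
    ]
    return attention_mask

prompts = [[11, 12, 13], [21, 22]]   # two prompts of different lengths (toy token ids)
max_prompt_len = max(len(p) for p in prompts)
padded = [_pad_to_max(p, max_prompt_len, padding_id=0) for p in prompts]
mask = _get_attention_mask_for_prompts(prompts, max_prompt_len)
print(padded)  # [[11, 12, 13], [0, 21, 22]]
print(mask)    # [[1, 1, 1], [0, 1, 1]]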
@@ -98,49 +108,53 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
     def forward(
         self,
         seq_group_meta_data_lists: List[SequenceGroupMetadata],
-        kv_cache: Optional = None,
-        input_metadata: Optional = None,
+        # kv_cache in the format [[dict() for _ in range(2)] for _ in range(32)]
+        kv_cache: Optional[List[List[Dict]]] = None,
+        input_metadata: Optional[InputMetadata] = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        kv_cache_size_0 = self.model.config.num_hidden_layers
-        kv_cache_size_1 = 2
-        seq_len = len(seq_group_meta_data_lists)
+        num_layers = self.model.config.num_hidden_layers
+        # One for key, one for value
+        decoder_kv_size = 2

         bigdl_input_ids = []
         bigdl_position_ids = []
         bigdl_attention_mask = []

         cur_seq_ids = []
-        bigdl_sampling_params = {}
-        max_context_len = 0
-        all_decoding = True
+        max_prompt_len = 0
+        # 0. Verify is_prompt or is_decoding
+        is_decoding_stage = not seq_group_meta_data_lists[0].is_prompt
+
+        # 1. Assemble bigdl_input_ids
         for seq_group_meta_data in seq_group_meta_data_lists:
-            req_id = seq_group_meta_data.request_id
-            all_decoding = all_decoding and (not seq_group_meta_data.is_prompt)
+            # req_id = seq_group_meta_data.request_id
+            # is_decoding_stage = is_decoding_stage and (not seq_group_meta_data.is_prompt)
             seq_ids = list(seq_group_meta_data.seq_data.keys())
             seq_id = seq_ids[0]
             cur_seq_ids.append(seq_id)
             seq_data = seq_group_meta_data.seq_data[seq_id]

             cur_seq_input_ids = seq_data.get_token_ids()
-            context_len = seq_data.get_len()
+            # context_len = seq_data.get_len()
             if seq_group_meta_data.is_prompt:
                 bigdl_input_ids.append(cur_seq_input_ids)
-                max_context_len = max(max_context_len, context_len)
+                max_prompt_len = max(max_prompt_len, seq_data.get_len())
             else:
                 bigdl_input_ids.append([cur_seq_input_ids[-1]])
+        # 1. Assemble bigdl_input_ids end

-            bigdl_sampling_params[seq_id] = seq_group_meta_data.sampling_params
+        if is_decoding_stage:

-        if all_decoding:
             bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
-                                                   kv_cache, kv_cache_size_0, kv_cache_size_1)
+                                                   kv_cache, num_layers, decoder_kv_size)
         else:
+            bigdl_attention_mask = _get_attention_mask_for_prompts(bigdl_input_ids, max_prompt_len)
             bigdl_input_ids = [
-                _pad_to_max(input_ids, max_context_len, self.pad_token_id)
+                _pad_to_max(input_ids, max_prompt_len, self.pad_token_id)
                 for input_ids in bigdl_input_ids
             ]

-        if all_decoding:
+        if is_decoding_stage:
             cur_seq_len = bigdl_kv_cache[0][0].size(2)
             for seq_group_meta_data in seq_group_meta_data_lists:
                 seq_ids = list(seq_group_meta_data.seq_data.keys())
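For reference, the kv_cache argument is now annotated as a per-layer list of [key, value] dicts keyed by sequence id, matching the new comment in the signature. A small sketch of that layout follows; the tensor shapes are illustrative assumptions, not taken from this diff.

import torch

num_layers = 32          # e.g. config.num_hidden_layers
decoder_kv_size = 2      # one slot for keys, one for values

# kv_cache[layer][0][seq_id] -> cached keys, kv_cache[layer][1][seq_id] -> cached values
kv_cache = [[dict() for _ in range(decoder_kv_size)] for _ in range(num_layers)]

seq_id = 0
# Assumed per-sequence shape (num_heads, seq_len, head_dim); dim=1 is the length
# that prepare_kv_cache later pads to a common max_kv_len.
kv_cache[0][0][seq_id] = torch.zeros(32, 5, 128)
kv_cache[0][1][seq_id] = torch.zeros(32, 5, 128)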
@@ -152,7 +166,8 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
                 bigdl_attention_mask.append(cur_attention_mask)

         bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
-        if all_decoding:
+
+        if is_decoding_stage:
             bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
             bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
             kwargs = {
@@ -166,6 +181,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
         else:
             kwargs = {
                 "input_ids": bigdl_input_ids,
+                "attention_mask": torch.tensor(bigdl_attention_mask, device=self.device),
                 # "position_ids": bigdl_position_ids,
                 "past_key_values": None,
                 "use_cache": True,
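In the prompt branch the left-padded batch now travels together with its mask. A minimal sketch of the resulting call pattern; the tensors are toy values and `model` stands in for any Hugging Face-style causal LM.

import torch

device = "cpu"  # placeholder; the real code uses self.device

# Left-padded prompt batch and its mask, as produced by the helpers above.
bigdl_input_ids = torch.tensor([[0, 21, 22], [11, 12, 13]], device=device)
bigdl_attention_mask = [[0, 1, 1], [1, 1, 1]]

kwargs = {
    "input_ids": bigdl_input_ids,
    "attention_mask": torch.tensor(bigdl_attention_mask, device=device),
    "past_key_values": None,
    "use_cache": True,
}
# outputs = model(**kwargs)   # hypothetical call; returns logits plus past_key_values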
@@ -190,7 +206,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
            # logger.info(f"before: {tmp['allocated_bytes.all.current']}")

            self.update_kv_cache(cur_seq_ids,
-                                kv_cache, kv_cache_size_0, kv_cache_size_1)
+                                kv_cache, num_layers, decoder_kv_size)

            # tmp = torch.xpu.memory_stats()
            # logger.info(f"after: {tmp['allocated_bytes.all.current']}")
@@ -90,45 +90,60 @@ class BigDLModelForCausalLM(nn.Module):
         cur_seq_ids: List[int],
         seq_group_meta_data_lists: List[SequenceGroupMetadata],
         kv_cache: Dict,
-        kv_cache_size_0: int,
+        num_layers: int,
         kv_cache_size_1: int,
     ):
         max_seq_limit = self.max_seq_limit
         if (self.last_kv_cache is not None) and cur_seq_ids == self.last_seq_ids:
             if self.last_kv_cache[0][0].size(2) < max_seq_limit * 1.5:
                 bigdl_kv_cache = self.last_kv_cache
+                # Immediately set it to None to decrease ref-count
+                self.last_kv_cache = None
             else:
                 bigdl_kv_cache = [[tmp.narrow(2, self.last_kv_cache[0][0].size(2)
                                               - max_seq_limit, max_seq_limit)
                                    for tmp in tmp_list] for tmp_list in self.last_kv_cache]
                 del self.last_kv_cache
+                self.last_kv_cache = None
         else:
             del self.last_kv_cache
             bigdl_kv_cache = []
-            for i in range(kv_cache_size_0):
+            max_kv_len = max(kv_cache[0][0][processed_seq_id].size(dim=1)
+                             for processed_seq_id in cur_seq_ids)
+            max_kv_len = min(max_kv_len, max_seq_limit)
+            for layer in range(num_layers):
                 cur_list = []
-                for j in range(kv_cache_size_1):
-                    views = []
-                    max_len = 0
-                    for seq_group_meta_data in seq_group_meta_data_lists:
-                        seq_ids = list(seq_group_meta_data.seq_data.keys())
-                        seq_id = seq_ids[0]
-                        seq_data = seq_group_meta_data.seq_data[seq_id]
-                        view_size = [1] + list(kv_cache[i][j][seq_id].shape)
-                        views.append(kv_cache[i][j][seq_id].view(view_size))
-                        max_len = max(max_len, view_size[2])
+                for kv in range(kv_cache_size_1):
+                    kv_list = []
+                    # for seq_group_meta_data in seq_group_meta_data_lists:
+                    #     seq_ids = list(seq_group_meta_data.seq_data.keys())
+                    #     seq_id = seq_ids[0]
+                    #     # seq_data = seq_group_meta_data.seq_data[seq_id]
+                    #     view_size = [1] + list(kv_cache[layer][kv][seq_id].shape)
+                    #     kv_list.append(kv_cache[layer][kv][seq_id].view(view_size))
+                    for seq_id in cur_seq_ids:
+                        processed_kv_cache = kv_cache[layer][kv][seq_id]
+                        # Clean
+                        kv_cache[layer][kv][processed_kv_cache] = None
+                        if processed_kv_cache.size(dim=1) != max_kv_len:
+                            processed_kv_cache = _pad_kv_cache_view(processed_kv_cache, max_kv_len,
+                                                                    self.device, 1)
+                        # Do padding
+                        kv_list.append(processed_kv_cache)
+                    current_layer_kv_cache = torch.stack(kv_list, dim=0)
+                    kv_list.clear()

-                    views = [_pad_kv_cache_view(v, max_len, self.device) for v in views]
-                    cur_view = torch.cat(views, dim=0)
+                    # kv_list = [_pad_kv_cache_view(v, max_kv_len, self.device) for v in kv_list]
+                    # cur_view = torch.cat(kv_list, dim=0)

-                    if cur_view.size(2) > max_seq_limit:
-                        cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
-                    cur_list.append(cur_view)
+                    # if cur_view.size(2) > max_seq_limit:
+                    #     cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
+                    cur_list.append(current_layer_kv_cache)

-                    for seq_group_meta_data in seq_group_meta_data_lists:
-                        seq_ids = list(seq_group_meta_data.seq_data.keys())
-                        seq_id = seq_ids[0]
-                        del kv_cache[i][j][seq_id]
+                    # for seq_group_meta_data in seq_group_meta_data_lists:
+                    #     seq_ids = list(seq_group_meta_data.seq_data.keys())
+                    #     seq_id = seq_ids[0]
+                    #     del kv_cache[layer][kv][seq_id]
                 bigdl_kv_cache.append(cur_list)

         return bigdl_kv_cache
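The rewritten loop pads each sequence's cached keys and values to one shared max_kv_len and stacks them into a single batched tensor per layer and slot. The snippet below illustrates just that padding-and-stacking step; pad_kv_view is a hypothetical stand-in for the repo's _pad_kv_cache_view helper (assumed here to left-pad or truncate the given dimension), and the shapes are made up.

import torch

def pad_kv_view(t: torch.Tensor, target_len: int, dim: int = 1) -> torch.Tensor:
    # Hypothetical stand-in: zero-pad (or truncate) dimension `dim` to target_len.
    cur_len = t.size(dim)
    if cur_len == target_len:
        return t
    if cur_len > target_len:
        return t.narrow(dim, cur_len - target_len, target_len)
    pad_shape = list(t.shape)
    pad_shape[dim] = target_len - cur_len
    return torch.cat([torch.zeros(pad_shape, dtype=t.dtype), t], dim=dim)

# Per-sequence KV tensors for one layer/slot, (num_heads, seq_len, head_dim), lengths differ.
per_seq_kv = {7: torch.randn(4, 5, 64), 9: torch.randn(4, 3, 64)}
cur_seq_ids = [7, 9]

max_kv_len = max(per_seq_kv[s].size(dim=1) for s in cur_seq_ids)
kv_list = [pad_kv_view(per_seq_kv[s], max_kv_len, dim=1) for s in cur_seq_ids]
current_layer_kv_cache = torch.stack(kv_list, dim=0)
print(current_layer_kv_cache.shape)  # torch.Size([2, 4, 5, 64])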
@@ -139,15 +154,15 @@ class BigDLModelForCausalLM(nn.Module):
         self,
         cur_seq_ids: List[int],
         kv_cache,
-        kv_cache_size_0: int,
+        layer: int,
         kv_cache_size_1: int,
     ) -> None:
-        for i in range(kv_cache_size_0):
+        for i in range(layer):
             for j in range(kv_cache_size_1):
-                index = 0
+                batch_dim = 0
                 for seq_id in cur_seq_ids:
-                    kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][index]
-                    index = index + 1
+                    kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][batch_dim]
+                    batch_dim = batch_dim + 1

     def forward(
         self,
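After a decoding step, update_kv_cache hands each row of the batched cache back to its sequence id by walking the batch dimension, as in this toy-sized sketch (sizes and shapes are made up; the real model uses config.num_hidden_layers layers).

import torch

cur_seq_ids = [7, 9]
num_layers, kv_size = 2, 2

# last_kv_cache[i][j] is batched: (batch, num_heads, seq_len, head_dim)
last_kv_cache = [[torch.randn(len(cur_seq_ids), 4, 6, 64) for _ in range(kv_size)]
                 for _ in range(num_layers)]
kv_cache = [[dict() for _ in range(kv_size)] for _ in range(num_layers)]

for i in range(num_layers):
    for j in range(kv_size):
        batch_dim = 0
        for seq_id in cur_seq_ids:
            # Row batch_dim of the batched tensor belongs to this sequence.
            kv_cache[i][j][seq_id] = last_kv_cache[i][j][batch_dim]
            batch_dim += 1

print(kv_cache[0][0][9].shape)  # torch.Size([4, 6, 64])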