[LLM] vLLM: fix memory leak in prepare_kv_cache (#9616)

Revert modification in prepare_kv_cache to fix memory leak.
This commit is contained in:
Xiangyu Tian 2023-12-07 10:08:18 +08:00 committed by GitHub
parent 13d47955a8
commit 0327169b50
2 changed files with 24 additions and 35 deletions

View file

@ -35,7 +35,6 @@ from transformers.generation.logits_process import (
TopPLogitsWarper,
)
logger = init_logger(__name__)
@ -161,18 +160,18 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
seq_id = seq_ids[0]
seq_data = seq_group_meta_data.seq_data[seq_id]
cur_pos = seq_data.get_len()
bigdl_position_ids.append([cur_pos - 1])
# bigdl_position_ids.append([cur_pos - 1])
cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
bigdl_attention_mask.append(cur_attention_mask)
bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
if is_decoding_stage:
bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
# bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
kwargs = {
"input_ids": bigdl_input_ids,
"position_ids": bigdl_position_ids,
# "position_ids": bigdl_position_ids,
"attention_mask": bigdl_attention_mask,
"past_key_values": bigdl_kv_cache,
"use_cache": True,
@ -199,6 +198,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
# self.last_kv_cache = outputs.past_key_values
self._set_last_seq_ids(cur_seq_ids[:])
self._set_last_kv_cache(outputs.past_key_values)
# pdb.set_trace()
logits = outputs.logits[:, -1, :]
bigdl_output = self.sampler(logits, input_metadata, st_timestamp)

View file

@ -108,42 +108,31 @@ class BigDLModelForCausalLM(nn.Module):
else:
del self.last_kv_cache
bigdl_kv_cache = []
max_kv_len = max(kv_cache[0][0][processed_seq_id].size(dim=1)
for processed_seq_id in cur_seq_ids)
max_kv_len = max(
seq_group_meta_data.seq_data[next(iter(seq_group_meta_data.seq_data))].get_len()
for seq_group_meta_data in seq_group_meta_data_lists
)
max_kv_len = min(max_kv_len, max_seq_limit)
for layer in range(num_layers):
for i in range(num_layers):
cur_list = []
for kv in range(kv_cache_size_1):
kv_list = []
# for seq_group_meta_data in seq_group_meta_data_lists:
# seq_ids = list(seq_group_meta_data.seq_data.keys())
# seq_id = seq_ids[0]
# # seq_data = seq_group_meta_data.seq_data[seq_id]
# view_size = [1] + list(kv_cache[layer][kv][seq_id].shape)
# kv_list.append(kv_cache[layer][kv][seq_id].view(view_size))
for seq_id in cur_seq_ids:
processed_kv_cache = kv_cache[layer][kv][seq_id]
# Clean
kv_cache[layer][kv][processed_kv_cache] = None
if processed_kv_cache.size(dim=1) != max_kv_len:
processed_kv_cache = _pad_kv_cache_view(processed_kv_cache, max_kv_len,
self.device, 1)
# Do padding
kv_list.append(processed_kv_cache)
current_layer_kv_cache = torch.stack(kv_list, dim=0)
kv_list.clear()
for j in range(kv_cache_size_1):
views = []
for seq_group_meta_data in seq_group_meta_data_lists:
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
view_size = [1] + list(kv_cache[i][j][seq_id].shape)
views.append(kv_cache[i][j][seq_id].view(view_size))
# kv_list = [_pad_kv_cache_view(v, max_kv_len, self.device) for v in kv_list]
# cur_view = torch.cat(kv_list, dim=0)
views = [_pad_kv_cache_view(v, max_kv_len, self.device) for v in views]
cur_view = torch.cat(views, dim=0)
cur_list.append(cur_view)
# if cur_view.size(2) > max_seq_limit:
# cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
cur_list.append(current_layer_kv_cache)
for seq_group_meta_data in seq_group_meta_data_lists:
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
del kv_cache[i][j][seq_id]
# for seq_group_meta_data in seq_group_meta_data_lists:
# seq_ids = list(seq_group_meta_data.seq_data.keys())
# seq_id = seq_ids[0]
# del kv_cache[layer][kv][seq_id]
bigdl_kv_cache.append(cur_list)
return bigdl_kv_cache