[LLM] vLLM: fix memory leak in prepare_kv_cache (#9616)
Revert the modification in prepare_kv_cache to fix a memory leak.
This commit is contained in:
parent 13d47955a8
commit 0327169b50
2 changed files with 24 additions and 35 deletions
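For context on the leak: a minimal illustrative sketch of the pattern involved, not BigDL's actual code. The kv_cache[layer][k_or_v][seq_id] layout and the helper name below are assumptions read off the hunks that follow; the point is that per-sequence K/V tensors held in a dict stay alive after they have been copied into the batched cache unless their entries are explicitly deleted.

import torch
import torch.nn.functional as F

def gather_and_release(kv_cache, layer, k_or_v, seq_ids, max_kv_len):
    # Assumed layout: kv_cache[layer][k_or_v][seq_id] is a [num_heads, seq_len, head_dim] tensor.
    views = []
    for seq_id in seq_ids:
        t = kv_cache[layer][k_or_v][seq_id]
        # Left-pad along the sequence dimension so every sequence reaches max_kv_len.
        views.append(F.pad(t, (0, 0, max_kv_len - t.size(1), 0)).unsqueeze(0))
    batched = torch.cat(views, dim=0)  # [batch, num_heads, max_kv_len, head_dim]
    # Dropping the per-sequence entries is what keeps memory bounded: without this,
    # the dict keeps every old K/V tensor alive step after step.
    for seq_id in seq_ids:
        del kv_cache[layer][k_or_v][seq_id]
    return batched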
@@ -35,7 +35,6 @@ from transformers.generation.logits_process import (
    TopPLogitsWarper,
)

logger = init_logger(__name__)
@@ -161,18 +160,18 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
            seq_id = seq_ids[0]
            seq_data = seq_group_meta_data.seq_data[seq_id]
            cur_pos = seq_data.get_len()
            bigdl_position_ids.append([cur_pos - 1])
            # bigdl_position_ids.append([cur_pos - 1])
            cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
            bigdl_attention_mask.append(cur_attention_mask)

        bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)

        if is_decoding_stage:
            bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
            # bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
            bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
            kwargs = {
                "input_ids": bigdl_input_ids,
                "position_ids": bigdl_position_ids,
                # "position_ids": bigdl_position_ids,
                "attention_mask": bigdl_attention_mask,
                "past_key_values": bigdl_kv_cache,
                "use_cache": True,
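As a reading aid, a small sketch of the decode-step inputs assembled in the hunk above, using the same mask expression; build_decode_inputs and its arguments are invented for illustration and are not part of the model class.

import torch

def build_decode_inputs(seq_lens, cur_seq_len, device="cpu"):
    # Each entry of seq_lens plays the role of cur_pos above.
    position_ids, attention_mask = [], []
    for cur_pos in seq_lens:
        position_ids.append([cur_pos - 1])  # position of the token being decoded
        # Zeros for the first (cur_seq_len - cur_pos + 1) slots, ones for the
        # remaining cur_pos, mirroring the expression in the diff (total length cur_seq_len + 1).
        attention_mask.append([0] * (cur_seq_len - cur_pos + 1) + [1] * cur_pos)
    return (torch.tensor(position_ids, device=device),
            torch.tensor(attention_mask, device=device))

For example, build_decode_inputs([5, 3], cur_seq_len=5) yields two masks of length 6, with 5 and 3 trailing ones respectively.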
@@ -199,6 +198,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
        # self.last_kv_cache = outputs.past_key_values
        self._set_last_seq_ids(cur_seq_ids[:])
        self._set_last_kv_cache(outputs.past_key_values)
        # pdb.set_trace()

        logits = outputs.logits[:, -1, :]
        bigdl_output = self.sampler(logits, input_metadata, st_timestamp)
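The two _set_last_* calls above presumably just stash this step's results on the module so the next step can reuse and then release them; this diff does not show their bodies, so the following is only a guess at what such helpers look like.

def _set_last_seq_ids(self, seq_ids):
    # Remember which sequences the cached key/values belong to.
    self.last_seq_ids = seq_ids

def _set_last_kv_cache(self, kv_cache):
    # Hold a single reference to the latest past_key_values; prepare_kv_cache
    # (next hunk) does `del self.last_kv_cache` before building the next batch,
    # so only one step's cache stays alive at a time.
    self.last_kv_cache = kv_cache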
@@ -108,42 +108,31 @@ class BigDLModelForCausalLM(nn.Module):
        else:
            del self.last_kv_cache
            bigdl_kv_cache = []
            max_kv_len = max(kv_cache[0][0][processed_seq_id].size(dim=1)
                             for processed_seq_id in cur_seq_ids)
            max_kv_len = max(
                seq_group_meta_data.seq_data[next(iter(seq_group_meta_data.seq_data))].get_len()
                for seq_group_meta_data in seq_group_meta_data_lists
            )
            max_kv_len = min(max_kv_len, max_seq_limit)
            for layer in range(num_layers):

            for i in range(num_layers):
                cur_list = []
                for kv in range(kv_cache_size_1):
                    kv_list = []
                    # for seq_group_meta_data in seq_group_meta_data_lists:
                    #     seq_ids = list(seq_group_meta_data.seq_data.keys())
                    #     seq_id = seq_ids[0]
                    #     # seq_data = seq_group_meta_data.seq_data[seq_id]
                    #     view_size = [1] + list(kv_cache[layer][kv][seq_id].shape)
                    #     kv_list.append(kv_cache[layer][kv][seq_id].view(view_size))
                    for seq_id in cur_seq_ids:
                        processed_kv_cache = kv_cache[layer][kv][seq_id]
                        # Clean
                        kv_cache[layer][kv][processed_kv_cache] = None
                        if processed_kv_cache.size(dim=1) != max_kv_len:
                            processed_kv_cache = _pad_kv_cache_view(processed_kv_cache, max_kv_len,
                                                                    self.device, 1)
                        # Do padding
                        kv_list.append(processed_kv_cache)
                    current_layer_kv_cache = torch.stack(kv_list, dim=0)
                    kv_list.clear()
                for j in range(kv_cache_size_1):
                    views = []
                    for seq_group_meta_data in seq_group_meta_data_lists:
                        seq_ids = list(seq_group_meta_data.seq_data.keys())
                        seq_id = seq_ids[0]
                        view_size = [1] + list(kv_cache[i][j][seq_id].shape)
                        views.append(kv_cache[i][j][seq_id].view(view_size))

                    # kv_list = [_pad_kv_cache_view(v, max_kv_len, self.device) for v in kv_list]
                    # cur_view = torch.cat(kv_list, dim=0)
                    views = [_pad_kv_cache_view(v, max_kv_len, self.device) for v in views]
                    cur_view = torch.cat(views, dim=0)
                    cur_list.append(cur_view)

                    # if cur_view.size(2) > max_seq_limit:
                    #     cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
                    cur_list.append(current_layer_kv_cache)
                    for seq_group_meta_data in seq_group_meta_data_lists:
                        seq_ids = list(seq_group_meta_data.seq_data.keys())
                        seq_id = seq_ids[0]
                        del kv_cache[i][j][seq_id]

                    # for seq_group_meta_data in seq_group_meta_data_lists:
                    #     seq_ids = list(seq_group_meta_data.seq_data.keys())
                    #     seq_id = seq_ids[0]
                    #     del kv_cache[layer][kv][seq_id]
                bigdl_kv_cache.append(cur_list)

        return bigdl_kv_cache
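Both variants in this hunk rely on _pad_kv_cache_view, which is not shown in the diff. A guessed sketch of what such a helper does follows; the signature, the zero fill, and the left-padding side are assumptions (the dim argument follows the dim=1 and default-dim calls above, which appear to operate before and after the leading batch dimension is added).

import torch

def _pad_kv_cache_view(t, seq_len, device, dim=2):
    # Pad a per-sequence K/V tensor along its sequence dimension up to seq_len
    # so caches of different lengths can be concatenated into one batch.
    cur_len = t.size(dim)
    if cur_len >= seq_len:
        return t
    pad_shape = list(t.shape)
    pad_shape[dim] = seq_len - cur_len
    padding = torch.zeros(pad_shape, dtype=t.dtype, device=device)
    # Padding on the left keeps the most recent tokens at the end of the cache,
    # matching the [0] * ... + [1] * ... attention masks built in the other file.
    return torch.cat([padding, t], dim=dim)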