[LLM] vLLM: fix memory leak in prepare_kv_cache (#9616)
Revert the modification in prepare_kv_cache to fix a memory leak.
parent 13d47955a8
commit 0327169b50
2 changed files with 24 additions and 35 deletions
@@ -35,7 +35,6 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )
 
-
 logger = init_logger(__name__)
 
 
@@ -161,18 +160,18 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
             seq_id = seq_ids[0]
             seq_data = seq_group_meta_data.seq_data[seq_id]
             cur_pos = seq_data.get_len()
-            bigdl_position_ids.append([cur_pos - 1])
+            # bigdl_position_ids.append([cur_pos - 1])
             cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
             bigdl_attention_mask.append(cur_attention_mask)
 
         bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
 
         if is_decoding_stage:
-            bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
+            # bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
             bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
             kwargs = {
                 "input_ids": bigdl_input_ids,
-                "position_ids": bigdl_position_ids,
+                # "position_ids": bigdl_position_ids,
                 "attention_mask": bigdl_attention_mask,
                 "past_key_values": bigdl_kv_cache,
                 "use_cache": True,
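Note on the hunk above: the decode path stops building and passing explicit position_ids and relies on the model's default position handling instead. As a minimal sketch of how decode-step positions can be recovered from a left-padded attention mask (a pattern similar to what transformers' generation utilities use; positions_from_attention_mask is a hypothetical helper, not part of this repository):

import torch

def positions_from_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: (batch, seq_len), 0 for left padding, 1 for real tokens.
    # cumsum gives 1-based positions for real tokens; subtract 1 for 0-based position ids.
    position_ids = attention_mask.long().cumsum(dim=-1) - 1
    # Padded slots get a dummy position; they are masked out by the attention mask anyway.
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids

# During decoding only the last token's position matters:
mask = torch.tensor([[0, 0, 1, 1, 1]])                        # one sequence with two pad slots
last_position = positions_from_attention_mask(mask)[:, -1:]   # tensor([[2]])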
@@ -199,6 +198,7 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
         # self.last_kv_cache = outputs.past_key_values
         self._set_last_seq_ids(cur_seq_ids[:])
         self._set_last_kv_cache(outputs.past_key_values)
+        # pdb.set_trace()
 
         logits = outputs.logits[:, -1, :]
         bigdl_output = self.sampler(logits, input_metadata, st_timestamp)
@@ -108,42 +108,31 @@ class BigDLModelForCausalLM(nn.Module):
         else:
             del self.last_kv_cache
             bigdl_kv_cache = []
-            max_kv_len = max(kv_cache[0][0][processed_seq_id].size(dim=1)
-                             for processed_seq_id in cur_seq_ids)
+            max_kv_len = max(
+                seq_group_meta_data.seq_data[next(iter(seq_group_meta_data.seq_data))].get_len()
+                for seq_group_meta_data in seq_group_meta_data_lists
+            )
             max_kv_len = min(max_kv_len, max_seq_limit)
-            for layer in range(num_layers):
+            for i in range(num_layers):
                 cur_list = []
-                for kv in range(kv_cache_size_1):
-                    kv_list = []
-                    # for seq_group_meta_data in seq_group_meta_data_lists:
-                    # seq_ids = list(seq_group_meta_data.seq_data.keys())
-                    # seq_id = seq_ids[0]
-                    # # seq_data = seq_group_meta_data.seq_data[seq_id]
-                    # view_size = [1] + list(kv_cache[layer][kv][seq_id].shape)
-                    # kv_list.append(kv_cache[layer][kv][seq_id].view(view_size))
-                    for seq_id in cur_seq_ids:
-                        processed_kv_cache = kv_cache[layer][kv][seq_id]
-                        # Clean
-                        kv_cache[layer][kv][processed_kv_cache] = None
-                        if processed_kv_cache.size(dim=1) != max_kv_len:
-                            processed_kv_cache = _pad_kv_cache_view(processed_kv_cache, max_kv_len,
-                                                                    self.device, 1)
-                        # Do padding
-                        kv_list.append(processed_kv_cache)
-                    current_layer_kv_cache = torch.stack(kv_list, dim=0)
-                    kv_list.clear()
+                for j in range(kv_cache_size_1):
+                    views = []
+                    for seq_group_meta_data in seq_group_meta_data_lists:
+                        seq_ids = list(seq_group_meta_data.seq_data.keys())
+                        seq_id = seq_ids[0]
+                        view_size = [1] + list(kv_cache[i][j][seq_id].shape)
+                        views.append(kv_cache[i][j][seq_id].view(view_size))
 
-                    # kv_list = [_pad_kv_cache_view(v, max_kv_len, self.device) for v in kv_list]
-                    # cur_view = torch.cat(kv_list, dim=0)
+                    views = [_pad_kv_cache_view(v, max_kv_len, self.device) for v in views]
+                    cur_view = torch.cat(views, dim=0)
+                    cur_list.append(cur_view)
 
-                    # if cur_view.size(2) > max_seq_limit:
-                    # cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
-                    cur_list.append(current_layer_kv_cache)
+                    for seq_group_meta_data in seq_group_meta_data_lists:
+                        seq_ids = list(seq_group_meta_data.seq_data.keys())
+                        seq_id = seq_ids[0]
+                        del kv_cache[i][j][seq_id]
 
-                    # for seq_group_meta_data in seq_group_meta_data_lists:
-                    # seq_ids = list(seq_group_meta_data.seq_data.keys())
-                    # seq_id = seq_ids[0]
-                    # del kv_cache[layer][kv][seq_id]
                 bigdl_kv_cache.append(cur_list)
 
         return bigdl_kv_cache
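For context on the prepare_kv_cache change above: each layer's per-sequence K/V tensors are given a leading batch dimension, padded to a common length, concatenated into one batched tensor, and the per-sequence entries are then deleted from kv_cache so their storage can actually be reclaimed. A minimal, self-contained sketch of that batching pattern, with toy shapes and a hypothetical _pad_kv_cache_view standing in for the repository's helper (left padding and the default padding dimension are assumptions here):

import torch

def _pad_kv_cache_view(t: torch.Tensor, seq_len: int, device, dim: int = 2) -> torch.Tensor:
    # Toy stand-in: left-pad (or truncate) dimension `dim` to seq_len.
    cur = t.size(dim)
    if cur < seq_len:
        pad_shape = list(t.shape)
        pad_shape[dim] = seq_len - cur
        return torch.cat([torch.zeros(pad_shape, dtype=t.dtype, device=device), t], dim=dim)
    if cur > seq_len:
        return t.narrow(dim, cur - seq_len, seq_len)
    return t

# Per-sequence cache: kv_cache[layer][k_or_v][seq_id] -> (num_heads, cur_len, head_dim).
device = "cpu"
num_layers, kv_cache_size_1 = 2, 2          # two layers, one K and one V slot each
seq_lens = {7: 3, 9: 5}                     # toy seq_id -> current sequence length
kv_cache = [[{s: torch.randn(4, n, 8) for s, n in seq_lens.items()}
             for _ in range(kv_cache_size_1)] for _ in range(num_layers)]

max_kv_len = max(seq_lens.values())
bigdl_kv_cache = []
for i in range(num_layers):
    cur_list = []
    for j in range(kv_cache_size_1):
        # View each sequence's cache with a leading batch dimension, pad, then batch them.
        views = [kv_cache[i][j][s].view([1] + list(kv_cache[i][j][s].shape)) for s in seq_lens]
        views = [_pad_kv_cache_view(v, max_kv_len, device) for v in views]
        cur_list.append(torch.cat(views, dim=0))    # (batch, num_heads, max_kv_len, head_dim)
        for s in list(seq_lens):
            del kv_cache[i][j][s]                   # release the per-sequence tensors
    bigdl_kv_cache.append(cur_list)

The del is the substance of the fix: the removed code assigned None using the tensor itself as the dictionary key (kv_cache[layer][kv][processed_kv_cache] = None), so the entries keyed by seq_id were never dropped and the per-sequence tensors kept accumulating across decode steps.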