Optimize speculative decoding PVC memory usage (#10329)
* optimize memory
* update
* update
* update
* support other models
* update
* fix style
This commit is contained in:
parent
cc796848ea
commit
9ea499ca68
1 changed file with 60 additions and 2 deletions
@@ -24,8 +24,8 @@ import time
 import os
 import copy
 import logging
-import warnings
-import inspect
+import transformers
+from packaging import version
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 from transformers import top_k_top_p_filtering, GenerationConfig, \
     LogitsProcessorList, StoppingCriteriaList
@@ -367,6 +367,55 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
             delta_past_value.to(torch.float32)
 
 
+def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_len=256,
+                               model_type="llama"):
+    from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
+        extend_kv_cache
+    enough_kv_room = True
+    if model_type not in ["chatglm", "qwen", "baichuan", "llama", "mistral",
+                          "gptj", "opt"]:
+        return past_key_values, False
+    cache_k = past_key_values[0][0]
+    if model_type == "chatglm":
+        cache_k = cache_k.permute(1, 2, 0, 3)
+    elif model_type == "qwen":
+        cache_k = cache_k.transpose(1, 2)
+
+    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value=(cache_k, None),
+                                                  seq_len=max_step_draft)
+    bsz, num_heads, current_seq_len, head_dim = cache_k.shape
+    device = past_key_values[0][0].device
+    if not enough_kv_room:
+        past_key_values = list(past_key_values)
+        for i in range(len(past_key_values)):
+            cache_k = past_key_values[i][0]
+            cache_v = past_key_values[i][1]
+            if model_type == "chatglm":
+                cache_k = cache_k.permute(1, 2, 0, 3)
+                cache_v = cache_v.permute(1, 2, 0, 3)
+            elif model_type == "qwen":
+                cache_k = cache_k.transpose(1, 2)
+                cache_v = cache_v.transpose(1, 2)
+            new_cache_k, new_cache_v = extend_kv_cache(
+                bsz,
+                num_heads,  # Support GQA
+                head_dim,
+                cache_k.size(2),
+                current_seq_len + max_step_draft + kv_alloc_block_len,
+                dtype=cache_v.dtype,
+                device=device)
+            new_cache_k[:] = cache_k
+            new_cache_v[:] = cache_v
+            if model_type == "chatglm":
+                past_key_values[i] = (new_cache_k.permute(2, 0, 1, 3),
+                                      new_cache_v.permute(2, 0, 1, 3))
+            elif model_type == "qwen":
+                past_key_values[i] = (new_cache_k.transpose(1, 2), new_cache_v.transpose(1, 2))
+            else:
+                past_key_values[i] = (new_cache_k, new_cache_v)
+    return past_key_values, not enough_kv_room
+
+
 @torch.no_grad()
 def speculative_generate(self,
                          inputs: Optional[torch.Tensor] = None,
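The new helper first normalizes each supported model's KV layout to the (batch, heads, seq, head_dim) order that is_enough_kv_cache_room_4_31 and extend_kv_cache are called with, then converts back after extension. A minimal sketch of those layout conversions with dummy tensors (the shapes below are illustrative, not taken from the PR):

import torch

bsz, heads, seq, dim = 1, 2, 8, 4  # illustrative sizes

# llama/mistral/baichuan/gptj/opt-style cache: already (bsz, heads, seq, dim)
llama_k = torch.zeros(bsz, heads, seq, dim)

# chatglm-style cache: (seq, bsz, heads, dim) -> normalized with permute(1, 2, 0, 3)
chatglm_k = torch.zeros(seq, bsz, heads, dim)
assert chatglm_k.permute(1, 2, 0, 3).shape == llama_k.shape

# qwen-style cache: (bsz, seq, heads, dim) -> normalized with transpose(1, 2)
qwen_k = torch.zeros(bsz, seq, heads, dim)
assert qwen_k.transpose(1, 2).shape == llama_k.shape

# After extension the helper converts back with the inverse ops:
# permute(2, 0, 1, 3) for chatglm and transpose(1, 2) for qwen.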
@@ -504,6 +553,9 @@ def speculative_generate(self,
 
     self.clear_benchmarks()
+
+    if self.device.type == 'xpu':
+        torch.xpu.empty_cache()
 
     # Example:
     # Target model forward for the first token
     # Step 1. target_model(prompt) -> a
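torch.xpu.empty_cache() asks the XPU caching allocator to return unused cached blocks to the driver before generation starts, analogous to torch.cuda.empty_cache() on CUDA devices. A small way to observe the effect, assuming an intel_extension_for_pytorch build on an Intel GPU that mirrors the CUDA-style memory statistics (torch.xpu.memory_reserved is an assumption here, not something used by the PR):

import torch
import intel_extension_for_pytorch  # noqa: F401  (registers the torch.xpu backend; assumed installed)

buf = torch.empty(64 * 1024 * 1024, dtype=torch.float16, device="xpu")
del buf  # freed by Python, but the allocator keeps the blocks cached

before = torch.xpu.memory_reserved()
torch.xpu.empty_cache()  # hand the cached blocks back to the driver
after = torch.xpu.memory_reserved()
print(f"reserved bytes: {before} -> {after}")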
@@ -562,6 +614,10 @@ def speculative_generate(self,
                                                        past_key_values_storage, _enable_ipex)
                 original_draft_past_key_values = draft_past_key_values
             else:
+                past_key_values, extend_kv = _check_and_extend_kv_cache(past_key_values,
+                                                                        max_step_draft,
+                                                                        max_new_tokens - step + 40,
+                                                                        self.config.model_type)
                 draft_past_key_values = past_key_values
             draft_generate_ids[:, 0] = current_input_ids
             draft_prob_list = []
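At this call site the kv_alloc_block_len argument is max_new_tokens - step + 40, so one extension covers the tokens still to be generated plus a margin, rather than growing the cache block by block. Inside the helper the new capacity is current_seq_len + max_step_draft + kv_alloc_block_len; with illustrative numbers (not taken from the commit):

# Illustrative sizes only; none of these values come from the PR.
current_seq_len = 1024   # tokens already held in the KV cache
max_step_draft = 8       # draft tokens speculated per round
max_new_tokens = 128     # overall generation budget
step = 16                # tokens generated so far

kv_alloc_block_len = max_new_tokens - step + 40               # 152
new_capacity = current_seq_len + max_step_draft + kv_alloc_block_len
print(new_capacity)      # 1184 -> room for the rest of the generation in one allocation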
@@ -742,6 +798,8 @@ def speculative_generate(self,
                 output_ids = greedy(logits)
             if self.device.type == 'xpu':
                 torch.xpu.synchronize()
+                if extend_kv:
+                    torch.xpu.empty_cache()
             toc = time.time()
             self.verify_time.append(toc - tic)
             self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])
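The extra flush is gated on extend_kv, presumably because only a step that just re-allocated the KV cache leaves the old, smaller buffers unused, while flushing on every verification step would add needless allocator overhead. Restated as a small standalone helper for illustration (the function name and signature are hypothetical, not part of the PR):

import torch


def release_extended_kv(device: torch.device, extend_kv: bool) -> None:
    """Flush the XPU caching allocator only when the KV cache was just re-allocated."""
    if device.type == 'xpu':
        torch.xpu.synchronize()      # wait for the verification forward to finish
        if extend_kv:
            torch.xpu.empty_cache()  # the pre-extension KV buffers are unreferenced now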