Optimize speculative decoding PVC memory usage (#10329)
* optimize memory * update * update * update * support other models * update * fix style
This commit is contained in:
parent
cc796848ea
commit
9ea499ca68
1 changed files with 60 additions and 2 deletions
|
|
@ -24,8 +24,8 @@ import time
|
|||
import os
|
||||
import copy
|
||||
import logging
|
||||
import warnings
|
||||
import inspect
|
||||
import transformers
|
||||
from packaging import version
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from transformers import top_k_top_p_filtering, GenerationConfig, \
|
||||
LogitsProcessorList, StoppingCriteriaList
|
||||
|
|
@ -367,6 +367,55 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
|
|||
delta_past_value.to(torch.float32)
|
||||
|
||||
|
||||
def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_len=256,
|
||||
model_type="llama"):
|
||||
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
|
||||
extend_kv_cache
|
||||
enough_kv_room = True
|
||||
if model_type not in ["chatglm", "qwen", "baichuan", "llama", "mistral",
|
||||
"gptj", "opt"]:
|
||||
return past_key_values, False
|
||||
cache_k = past_key_values[0][0]
|
||||
if model_type == "chatglm":
|
||||
cache_k = cache_k.permute(1, 2, 0, 3)
|
||||
elif model_type == "qwen":
|
||||
cache_k = cache_k.transpose(1, 2)
|
||||
|
||||
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value=(cache_k, None),
|
||||
seq_len=max_step_draft)
|
||||
bsz, num_heads, current_seq_len, head_dim = cache_k.shape
|
||||
device = past_key_values[0][0].device
|
||||
if not enough_kv_room:
|
||||
past_key_values = list(past_key_values)
|
||||
for i in range(len(past_key_values)):
|
||||
cache_k = past_key_values[i][0]
|
||||
cache_v = past_key_values[i][1]
|
||||
if model_type == "chatglm":
|
||||
cache_k = cache_k.permute(1, 2, 0, 3)
|
||||
cache_v = cache_v.permute(1, 2, 0, 3)
|
||||
elif model_type == "qwen":
|
||||
cache_k = cache_k.transpose(1, 2)
|
||||
cache_v = cache_v.transpose(1, 2)
|
||||
new_cache_k, new_cache_v = extend_kv_cache(
|
||||
bsz,
|
||||
num_heads, # Support GQA
|
||||
head_dim,
|
||||
cache_k.size(2),
|
||||
current_seq_len + max_step_draft + kv_alloc_block_len,
|
||||
dtype=cache_v.dtype,
|
||||
device=device)
|
||||
new_cache_k[:] = cache_k
|
||||
new_cache_v[:] = cache_v
|
||||
if model_type == "chatglm":
|
||||
past_key_values[i] = (new_cache_k.permute(2, 0, 1, 3),
|
||||
new_cache_v.permute(2, 0, 1, 3))
|
||||
elif model_type == "qwen":
|
||||
past_key_values[i] = (new_cache_k.transpose(1, 2), new_cache_v.transpose(1, 2))
|
||||
else:
|
||||
past_key_values[i] = (new_cache_k, new_cache_v)
|
||||
return past_key_values, not enough_kv_room
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def speculative_generate(self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
|
|
@ -504,6 +553,9 @@ def speculative_generate(self,
|
|||
|
||||
self.clear_benchmarks()
|
||||
|
||||
if self.device.type == 'xpu':
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
# Example:
|
||||
# Target model forward for the first token
|
||||
# Step 1. target_model(prompt) -> a
|
||||
|
|
@ -562,6 +614,10 @@ def speculative_generate(self,
|
|||
past_key_values_storage, _enable_ipex)
|
||||
original_draft_past_key_values = draft_past_key_values
|
||||
else:
|
||||
past_key_values, extend_kv = _check_and_extend_kv_cache(past_key_values,
|
||||
max_step_draft,
|
||||
max_new_tokens - step + 40,
|
||||
self.config.model_type)
|
||||
draft_past_key_values = past_key_values
|
||||
draft_generate_ids[:, 0] = current_input_ids
|
||||
draft_prob_list = []
|
||||
|
|
@ -742,6 +798,8 @@ def speculative_generate(self,
|
|||
output_ids = greedy(logits)
|
||||
if self.device.type == 'xpu':
|
||||
torch.xpu.synchronize()
|
||||
if extend_kv:
|
||||
torch.xpu.empty_cache()
|
||||
toc = time.time()
|
||||
self.verify_time.append(toc - tic)
|
||||
self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])
|
||||
|
|
|
|||
Loading…
Reference in a new issue