Support compress kv with lookahead (#11752)

* support compress kv with lookahead

* enough kv miss param
This commit is contained in:
Yina Chen 2024-08-09 12:39:57 +03:00 committed by GitHub
parent 93455aac09
commit 4b9c57cc60
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 32 additions and 12 deletions

View file

@ -287,7 +287,9 @@ def chatglm2_attention_forward(
else:
from transformers.configuration_utils import PretrainedConfig
self.config = self.config if hasattr(self, "config") else PretrainedConfig()
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
self.layer_number - 1,
q_len)
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_number - 1,
query_states, attention_mask, n_head // n_kv_head,

View file

@ -213,7 +213,9 @@ def chatglm4_attention_forward(
else:
from transformers.configuration_utils import PretrainedConfig
self.config = self.config if hasattr(self, "config") else PretrainedConfig()
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
self.layer_number - 1,
q_len)
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_number - 1,
query_states, attention_mask, n_head // n_kv_head,

View file

@ -127,7 +127,8 @@ def minicpm_attention_forward_original(
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
seq_len=q_len)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
@ -408,7 +409,8 @@ def minicpm_attention_forward_quantized(
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
seq_len=q_len)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
@ -821,7 +823,8 @@ def minicpm_attention_forward_original_4_39(
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
seq_len=q_len)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,

View file

@ -699,7 +699,8 @@ def mistral_attention_forward_4_36_quantized(
original_dtype = hidden_states.dtype
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
seq_len=q_len)
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
@ -916,7 +917,9 @@ def mistral_attention_forward_4_36_original(
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
self.layer_idx,
q_len)
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
@ -1172,7 +1175,8 @@ def mistral_attention_forward_4_39_original(
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
q_len)
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,

View file

@ -135,7 +135,9 @@ def attention_forward(
if past_key_value is not None:
# [CompressKV]
if use_compresskv:
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
self.layer_idx,
q_len)
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
query_states, attention_mask, self.num_key_value_groups,

View file

@ -440,7 +440,8 @@ def qwen2_attention_forward(
if past_key_value is not None:
# [CompressKV]
if use_compresskv:
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
q_len)
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
query_states, attention_mask, self.num_key_value_groups,
@ -471,6 +472,8 @@ def qwen2_attention_forward(
is_causal=True).to(hidden_states.dtype)
elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
import xe_addons
if use_compresskv:
attention_mask = None
if isinstance(past_key_value, DynamicFp8Cache):
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
attention_mask)

View file

@ -460,12 +460,16 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l
def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False):
if version.parse(trans_version) >= version.parse("4.36.0"):
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)):
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache,\
DynamicCompressCache
if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache,
DynamicCompressCache)):
if hasattr(past_key_values, "_seen_tokens"):
past_key_values._seen_tokens -= new_cache_size
else:
past_key_values.seen_tokens -= new_cache_size
if isinstance(past_key_values, DynamicCompressCache):
past_key_values.real_kv_len -= new_cache_size
for i, k in enumerate(past_key_values.key_cache):
past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :]