Support compress kv with lookahead (#11752)
* support compress kv with lookahead
* pass the missing q_len param to the enough-kv-room check
parent 93455aac09
commit 4b9c57cc60
7 changed files with 32 additions and 12 deletions
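The recurring change in the hunks below is that the enough-KV-room check now also receives the current query length: with lookahead, a decoding step can append several draft/verification tokens at once, so a check that assumes a single new token can report enough room when there is none. A toy illustration with made-up numbers (not taken from the code):

# Toy numbers only: a lookahead step appends several tokens at once.
cache_capacity = 1024   # KV slots allocated for this layer
cached_tokens = 1022    # tokens already stored
q_len = 3               # tokens appended by this lookahead/verification step

enough_room_single_token = cache_capacity - cached_tokens >= 1     # True, but misleading
enough_room_with_q_len = cache_capacity - cached_tokens >= q_len   # False -> grow the cache first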
@@ -287,7 +287,9 @@ def chatglm2_attention_forward(
     else:
         from transformers.configuration_utils import PretrainedConfig
         self.config = self.config if hasattr(self, "config") else PretrainedConfig()
-        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
+        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                      self.layer_number - 1,
+                                                      q_len)
         key_states, value_states = past_key_value.update(
             key_states, value_states, self.layer_number - 1,
             query_states, attention_mask, n_head // n_kv_head,
@@ -213,7 +213,9 @@ def chatglm4_attention_forward(
     else:
         from transformers.configuration_utils import PretrainedConfig
         self.config = self.config if hasattr(self, "config") else PretrainedConfig()
-        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
+        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                      self.layer_number - 1,
+                                                      q_len)
         key_states, value_states = past_key_value.update(
             key_states, value_states, self.layer_number - 1,
             query_states, attention_mask, n_head // n_kv_head,
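In the two ChatGLM hunks above, the compressed cache's update also receives query_states and attention_mask, presumably so the cache can score which cached positions are worth keeping. The following is only an assumption-laden sketch of such a scoring step (hypothetical helper, NOT the ipex_llm DynamicCompressCache implementation, and it ignores grouped-query head mapping):

import torch

def compress_kv_sketch(key_states, value_states, query_states, keep):
    # Sketch only: score each cached position by its mean attention logit against the
    # latest queries and keep the top-`keep` positions, preserving their original order.
    scores = torch.einsum("bhqd,bhkd->bhqk", query_states, key_states).mean(dim=2)
    idx = scores.topk(keep, dim=-1).indices.sort(dim=-1).values          # [bs, heads, keep]
    idx = idx.unsqueeze(-1).expand(-1, -1, -1, key_states.size(-1))      # [bs, heads, keep, head_dim]
    return key_states.gather(2, idx), value_states.gather(2, idx)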
@@ -127,7 +127,8 @@ def minicpm_attention_forward_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
@@ -408,7 +409,8 @@ def minicpm_attention_forward_quantized(
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
@@ -821,7 +823,8 @@ def minicpm_attention_forward_original_4_39(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
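The MiniCPM hunks above (and the Mistral ones below) mostly re-wrap an existing call, or add the q_len argument where it was missing. As a rough mental model of the check itself, a simplified stand-in might look like the following; this is a hypothetical function, not the real is_enough_kv_cache_room_4_36, which works against ipex_llm's pre-allocated cache layout:

import torch

def has_enough_kv_room(key_cache: torch.Tensor, kv_len: int, seq_len: int = 1) -> bool:
    # Hypothetical stand-in: the cache is assumed pre-allocated along the sequence
    # dimension, and the check passes only if the spare capacity can hold all
    # `seq_len` incoming tokens (layout assumed [bs, n_kv_head, allocated_len, head_dim]).
    allocated = key_cache.size(2)
    return allocated - kv_len >= seq_len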
@@ -699,7 +699,8 @@ def mistral_attention_forward_4_36_quantized(
     original_dtype = hidden_states.dtype

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  seq_len=q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
                                                 enough_kv_room,
@@ -916,7 +917,9 @@ def mistral_attention_forward_4_36_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                  self.layer_idx,
+                                                  q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
                                                 enough_kv_room,
@@ -1172,7 +1175,8 @@ def mistral_attention_forward_4_39_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
                                                 enough_kv_room,
@@ -135,7 +135,9 @@ def attention_forward(
     if past_key_value is not None:
         # [CompressKV]
         if use_compresskv:
-            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                          self.layer_idx,
+                                                          q_len)
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_idx,
                 query_states, attention_mask, self.num_key_value_groups,
@@ -440,7 +440,8 @@ def qwen2_attention_forward(
     if past_key_value is not None:
         # [CompressKV]
         if use_compresskv:
-            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                          q_len)
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_idx,
                 query_states, attention_mask, self.num_key_value_groups,
@@ -471,6 +472,8 @@ def qwen2_attention_forward(
                                                  is_causal=True).to(hidden_states.dtype)
     elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import xe_addons
+        if use_compresskv:
+            attention_mask = None
         if isinstance(past_key_value, DynamicFp8Cache):
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
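In the SDP branch above, the attention mask is dropped when the compressed cache is in use. A plausible reading (an assumption, not stated in the diff) is that a mask built for the full sequence length no longer matches the compressed KV length. The snippet below illustrates the shapes with made-up sizes, using PyTorch's scaled_dot_product_attention as a stand-in for the xe_addons kernel:

import torch
import torch.nn.functional as F

bs, n_head, q_len, head_dim = 1, 32, 4, 128
kv_len_compressed = 512   # the compressed cache holds fewer positions than the full sequence

query = torch.randn(bs, n_head, q_len, head_dim)
key = torch.randn(bs, n_head, kv_len_compressed, head_dim)
value = torch.randn(bs, n_head, kv_len_compressed, head_dim)

# A mask shaped for the original sequence length would not match kv_len_compressed,
# so the compress-kv branch calls the fused kernel with no mask at all.
out = F.scaled_dot_product_attention(query, key, value, attn_mask=None)
print(out.shape)   # torch.Size([1, 32, 4, 128])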
@@ -460,12 +460,16 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l

 def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False):
     if version.parse(trans_version) >= version.parse("4.36.0"):
-        from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
-        if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)):
+        from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache,\
+            DynamicCompressCache
+        if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache,
+                                        DynamicCompressCache)):
             if hasattr(past_key_values, "_seen_tokens"):
                 past_key_values._seen_tokens -= new_cache_size
             else:
                 past_key_values.seen_tokens -= new_cache_size
+            if isinstance(past_key_values, DynamicCompressCache):
+                past_key_values.real_kv_len -= new_cache_size

             for i, k in enumerate(past_key_values.key_cache):
                 past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :]
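The _crop_past_key_values change above extends cache rollback (used when drafted/lookahead tokens are rejected) to the compress cache, which additionally tracks real_kv_len. A minimal, self-contained sketch of the same trailing-token crop on a plain list-of-tensors cache (hypothetical helper, not the ipex_llm API):

import torch
from typing import List

def crop_kv_cache(key_cache: List[torch.Tensor],
                  value_cache: List[torch.Tensor],
                  n_dropped: int) -> None:
    # Discard the trailing `n_dropped` positions from every layer's cache,
    # assuming the layout [bs, n_head, seq_len, head_dim].
    for i in range(len(key_cache)):
        key_cache[i] = key_cache[i][:, :, :-n_dropped, :]
        value_cache[i] = value_cache[i][:, :, :-n_dropped, :]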