Support compress kv with lookahead (#11752)
* support compress kv with lookahead * enough kv miss param
This commit is contained in:
parent
93455aac09
commit
4b9c57cc60
7 changed files with 32 additions and 12 deletions
|
|
@ -287,7 +287,9 @@ def chatglm2_attention_forward(
|
||||||
else:
|
else:
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
self.config = self.config if hasattr(self, "config") else PretrainedConfig()
|
self.config = self.config if hasattr(self, "config") else PretrainedConfig()
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
|
||||||
|
self.layer_number - 1,
|
||||||
|
q_len)
|
||||||
key_states, value_states = past_key_value.update(
|
key_states, value_states = past_key_value.update(
|
||||||
key_states, value_states, self.layer_number - 1,
|
key_states, value_states, self.layer_number - 1,
|
||||||
query_states, attention_mask, n_head // n_kv_head,
|
query_states, attention_mask, n_head // n_kv_head,
|
||||||
|
|
|
||||||
|
|
@ -213,7 +213,9 @@ def chatglm4_attention_forward(
|
||||||
else:
|
else:
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
self.config = self.config if hasattr(self, "config") else PretrainedConfig()
|
self.config = self.config if hasattr(self, "config") else PretrainedConfig()
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
|
||||||
|
self.layer_number - 1,
|
||||||
|
q_len)
|
||||||
key_states, value_states = past_key_value.update(
|
key_states, value_states = past_key_value.update(
|
||||||
key_states, value_states, self.layer_number - 1,
|
key_states, value_states, self.layer_number - 1,
|
||||||
query_states, attention_mask, n_head // n_kv_head,
|
query_states, attention_mask, n_head // n_kv_head,
|
||||||
|
|
|
||||||
|
|
@ -127,7 +127,8 @@ def minicpm_attention_forward_original(
|
||||||
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
|
||||||
|
seq_len=q_len)
|
||||||
no_tp = not self.config.pretraining_tp > 1
|
no_tp = not self.config.pretraining_tp > 1
|
||||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||||
use_fuse_rope,
|
use_fuse_rope,
|
||||||
|
|
@ -408,7 +409,8 @@ def minicpm_attention_forward_quantized(
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
|
||||||
|
seq_len=q_len)
|
||||||
no_tp = not self.config.pretraining_tp > 1
|
no_tp = not self.config.pretraining_tp > 1
|
||||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||||
use_fuse_rope,
|
use_fuse_rope,
|
||||||
|
|
@ -821,7 +823,8 @@ def minicpm_attention_forward_original_4_39(
|
||||||
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
|
||||||
|
seq_len=q_len)
|
||||||
no_tp = not self.config.pretraining_tp > 1
|
no_tp = not self.config.pretraining_tp > 1
|
||||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||||
use_fuse_rope,
|
use_fuse_rope,
|
||||||
|
|
|
||||||
|
|
@ -699,7 +699,8 @@ def mistral_attention_forward_4_36_quantized(
|
||||||
original_dtype = hidden_states.dtype
|
original_dtype = hidden_states.dtype
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
|
||||||
|
seq_len=q_len)
|
||||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||||
use_fuse_rope,
|
use_fuse_rope,
|
||||||
enough_kv_room,
|
enough_kv_room,
|
||||||
|
|
@ -916,7 +917,9 @@ def mistral_attention_forward_4_36_original(
|
||||||
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
|
||||||
|
self.layer_idx,
|
||||||
|
q_len)
|
||||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||||
use_fuse_rope,
|
use_fuse_rope,
|
||||||
enough_kv_room,
|
enough_kv_room,
|
||||||
|
|
@ -1172,7 +1175,8 @@ def mistral_attention_forward_4_39_original(
|
||||||
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
|
||||||
|
q_len)
|
||||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||||
use_fuse_rope,
|
use_fuse_rope,
|
||||||
enough_kv_room,
|
enough_kv_room,
|
||||||
|
|
|
||||||
|
|
@ -135,7 +135,9 @@ def attention_forward(
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# [CompressKV]
|
# [CompressKV]
|
||||||
if use_compresskv:
|
if use_compresskv:
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
|
||||||
|
self.layer_idx,
|
||||||
|
q_len)
|
||||||
key_states, value_states = past_key_value.update(
|
key_states, value_states = past_key_value.update(
|
||||||
key_states, value_states, self.layer_idx,
|
key_states, value_states, self.layer_idx,
|
||||||
query_states, attention_mask, self.num_key_value_groups,
|
query_states, attention_mask, self.num_key_value_groups,
|
||||||
|
|
|
||||||
|
|
@ -440,7 +440,8 @@ def qwen2_attention_forward(
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# [CompressKV]
|
# [CompressKV]
|
||||||
if use_compresskv:
|
if use_compresskv:
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
|
||||||
|
q_len)
|
||||||
key_states, value_states = past_key_value.update(
|
key_states, value_states = past_key_value.update(
|
||||||
key_states, value_states, self.layer_idx,
|
key_states, value_states, self.layer_idx,
|
||||||
query_states, attention_mask, self.num_key_value_groups,
|
query_states, attention_mask, self.num_key_value_groups,
|
||||||
|
|
@ -471,6 +472,8 @@ def qwen2_attention_forward(
|
||||||
is_causal=True).to(hidden_states.dtype)
|
is_causal=True).to(hidden_states.dtype)
|
||||||
elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
||||||
import xe_addons
|
import xe_addons
|
||||||
|
if use_compresskv:
|
||||||
|
attention_mask = None
|
||||||
if isinstance(past_key_value, DynamicFp8Cache):
|
if isinstance(past_key_value, DynamicFp8Cache):
|
||||||
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
||||||
attention_mask)
|
attention_mask)
|
||||||
|
|
|
||||||
|
|
@ -460,12 +460,16 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l
|
||||||
|
|
||||||
def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False):
|
def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False):
|
||||||
if version.parse(trans_version) >= version.parse("4.36.0"):
|
if version.parse(trans_version) >= version.parse("4.36.0"):
|
||||||
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
|
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache,\
|
||||||
if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)):
|
DynamicCompressCache
|
||||||
|
if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache,
|
||||||
|
DynamicCompressCache)):
|
||||||
if hasattr(past_key_values, "_seen_tokens"):
|
if hasattr(past_key_values, "_seen_tokens"):
|
||||||
past_key_values._seen_tokens -= new_cache_size
|
past_key_values._seen_tokens -= new_cache_size
|
||||||
else:
|
else:
|
||||||
past_key_values.seen_tokens -= new_cache_size
|
past_key_values.seen_tokens -= new_cache_size
|
||||||
|
if isinstance(past_key_values, DynamicCompressCache):
|
||||||
|
past_key_values.real_kv_len -= new_cache_size
|
||||||
|
|
||||||
for i, k in enumerate(past_key_values.key_cache):
|
for i, k in enumerate(past_key_values.key_cache):
|
||||||
past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :]
|
past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :]
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue