From 4b9c57cc609f8f9defb52e92dff7b5be30bd483a Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Fri, 9 Aug 2024 12:39:57 +0300
Subject: [PATCH] Support compress kv with lookahead (#11752)

* support compress kv with lookahead

* enough kv miss param
---
 .../llm/src/ipex_llm/transformers/models/chatglm2.py   |  4 +++-
 .../llm/src/ipex_llm/transformers/models/chatglm4.py   |  4 +++-
 python/llm/src/ipex_llm/transformers/models/minicpm.py |  9 ++++++---
 python/llm/src/ipex_llm/transformers/models/mistral.py | 10 +++++++---
 python/llm/src/ipex_llm/transformers/models/phi3.py    |  4 +++-
 python/llm/src/ipex_llm/transformers/models/qwen2.py   |  5 ++++-
 python/llm/src/ipex_llm/transformers/speculative.py    |  8 ++++++--
 7 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index fe504960..dcf55e54 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -287,7 +287,9 @@ def chatglm2_attention_forward(
         else:
             from transformers.configuration_utils import PretrainedConfig
             self.config = self.config if hasattr(self, "config") else PretrainedConfig()
-            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                          self.layer_number - 1,
+                                                          q_len)
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_number - 1,
                 query_states, attention_mask, n_head // n_kv_head,
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index 46d7d780..53ec5e74 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -213,7 +213,9 @@ def chatglm4_attention_forward(
         else:
             from transformers.configuration_utils import PretrainedConfig
             self.config = self.config if hasattr(self, "config") else PretrainedConfig()
-            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                          self.layer_number - 1,
+                                                          q_len)
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_number - 1,
                 query_states, attention_mask, n_head // n_kv_head,
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py
index 50c0b9ce..fc31c4b7 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpm.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -127,7 +127,8 @@ def minicpm_attention_forward_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
@@ -408,7 +409,8 @@ def minicpm_attention_forward_quantized(
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
@@ -821,7 +823,8 @@ def minicpm_attention_forward_original_4_39(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index 61e507b2..689d9108 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -699,7 +699,8 @@ def mistral_attention_forward_4_36_quantized(
     original_dtype = hidden_states.dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  seq_len=q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
                                                 enough_kv_room,
@@ -916,7 +917,9 @@ def mistral_attention_forward_4_36_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                  self.layer_idx,
+                                                  q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
                                                 enough_kv_room,
@@ -1172,7 +1175,8 @@ def mistral_attention_forward_4_39_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
                                                 enough_kv_room,
diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index 60e3c394..443a9921 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -135,7 +135,9 @@ def attention_forward(
     if past_key_value is not None:
         # [CompressKV]
         if use_compresskv:
-            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                          self.layer_idx,
+                                                          q_len)
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_idx,
                 query_states, attention_mask, self.num_key_value_groups,
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 0306bb94..df392662 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -440,7 +440,8 @@ def qwen2_attention_forward(
     if past_key_value is not None:
         # [CompressKV]
         if use_compresskv:
-            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                          q_len)
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_idx,
                 query_states, attention_mask, self.num_key_value_groups,
@@ -471,6 +472,8 @@ def qwen2_attention_forward(
                                                      is_causal=True).to(hidden_states.dtype)
     elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import xe_addons
+        if use_compresskv:
+            attention_mask = None
         if isinstance(past_key_value, DynamicFp8Cache):
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py
index 6d2e0842..8667da8d 100644
--- a/python/llm/src/ipex_llm/transformers/speculative.py
+++ b/python/llm/src/ipex_llm/transformers/speculative.py
@@ -460,12 +460,16 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l
 
 
 def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False):
     if version.parse(trans_version) >= version.parse("4.36.0"):
-        from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
-        if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)):
+        from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache,\
+            DynamicCompressCache
+        if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache,
+                                        DynamicCompressCache)):
             if hasattr(past_key_values, "_seen_tokens"):
                 past_key_values._seen_tokens -= new_cache_size
             else:
                 past_key_values.seen_tokens -= new_cache_size
+            if isinstance(past_key_values, DynamicCompressCache):
+                past_key_values.real_kv_len -= new_cache_size
             for i, k in enumerate(past_key_values.key_cache):
                 past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :]
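
For context, a minimal sketch of the capacity check that this patch threads q_len through. This is not the ipex-llm implementation of is_enough_kv_cache_room_4_36; the helper name and numbers below are hypothetical and only illustrate the idea: with lookahead/speculative decoding a step can append several tokens at once, so the KV-cache room check has to account for the full query length instead of assuming a single new token.

# Hypothetical sketch (not the ipex-llm helper): why the room check needs q_len.
def has_enough_kv_room(allocated_len: int, seen_tokens: int, seq_len: int = 1) -> bool:
    """Return True if a cache preallocated for allocated_len tokens can take seq_len more."""
    return allocated_len - seen_tokens >= seq_len

# A cache preallocated for 1024 tokens with 1022 already filled:
assert has_enough_kv_room(1024, 1022, seq_len=1)      # plain decoding: one more token fits
assert not has_enough_kv_room(1024, 1022, seq_len=4)  # 4 lookahead tokens: cache must be extended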