From 670ad887fc0561750ce673b6927f8a3b4744d0c4 Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Tue, 30 Jul 2024 06:16:42 +0300
Subject: [PATCH] Qwen support compress kv (#11680)

* Qwen support compress kv

* fix style

* fix
---
 python/llm/src/ipex_llm/transformers/kv.py    | 23 +++++++++++++++-
 .../src/ipex_llm/transformers/models/llama.py | 12 ++++-----
 .../ipex_llm/transformers/models/mistral.py   | 12 ++++-----
 .../src/ipex_llm/transformers/models/qwen2.py | 27 +++++++++++++++----
 4 files changed, 56 insertions(+), 18 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/kv.py b/python/llm/src/ipex_llm/transformers/kv.py
index c6e79dc6..ae253239 100644
--- a/python/llm/src/ipex_llm/transformers/kv.py
+++ b/python/llm/src/ipex_llm/transformers/kv.py
@@ -259,7 +259,28 @@ class DynamicCompressCache(DynamicCache):
                 num_key_value_groups=num_key_value_groups)
             self.key_cache.append(key_states_compress)
             self.value_cache.append(value_states_compress)
-            return key_states, value_states
+
+            k_cache_compressed, v_cache_compressed = init_kv_cache(
+                bsz, num_heads, head_dim,
+                0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                key_states.dtype, key_states.device
+            )
+            k_cache_compressed, v_cache_compressed = append_kv_cache(
+                k_cache_compressed, v_cache_compressed,
+                key_states_compress, value_states_compress)
+            self.key_cache[layer_idx] = k_cache_compressed
+            self.value_cache[layer_idx] = v_cache_compressed
+
+            if key_states.stride(2) != head_dim:
+                k_cache, v_cache = init_kv_cache(
+                    bsz, num_heads, head_dim,
+                    0, key_states.size(2),
+                    key_states.dtype, key_states.device
+                )
+                k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
+                return k_cache, v_cache
+            else:
+                return key_states, value_states
         else:
             cache_k = self.key_cache[layer_idx]
             cache_v = self.value_cache[layer_idx]
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 9bc41e89..c02bfb3d 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1289,7 +1289,7 @@ def llama_attention_forward_4_41_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
-    # [SnapKV]
+    # [CompressKV]
     use_compresskv = should_use_compresskv(hidden_states)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
@@ -1324,7 +1324,7 @@ def llama_attention_forward_4_41_original(
                                                                        self.rotary_emb.base,)
         kv_seq_len += 1
         # update past_key_value's seem_tokens and kv caches.
-        # [SnapKV]
+        # [CompressKV]
         if use_compresskv:
             past_key_value.update_seen_tokens(self.layer_idx, q_len)
             kv_seq_len = past_key_value.get_seq_length()
@@ -1496,7 +1496,7 @@ def llama_attention_forward_4_41_original(
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
         if use_compresskv:
-            # [SnapKV] set attention_mask = None
+            # [CompressKV] set attention_mask = None
             new_attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
@@ -1833,7 +1833,7 @@ def llama_attention_forward_4_38_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
-    # [SnapKV]
+    # [CompressKV]
     use_compresskv = should_use_compresskv(hidden_states)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
@@ -1868,7 +1868,7 @@ def llama_attention_forward_4_38_original(
                                                                        self.rotary_emb.base,)
         kv_seq_len += 1
         # update past_key_value's seem_tokens and kv caches.
-        # [SnapKV]
+        # [CompressKV]
         if use_compresskv:
             past_key_value.update_seen_tokens(self.layer_idx, q_len)
             kv_seq_len = past_key_value.get_seq_length()
@@ -2039,7 +2039,7 @@ def llama_attention_forward_4_38_original(
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
         if use_compresskv:
-            # [SnapKV] set attention_mask = None
+            # [CompressKV] set attention_mask = None
             new_attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index 93825592..a3ed08de 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -897,7 +897,7 @@ def mistral_attention_forward_4_36_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
-    # [SnapKV]
+    # [CompressKV]
     use_compresskv = should_use_compresskv(hidden_states)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
@@ -930,7 +930,7 @@ def mistral_attention_forward_4_36_original(
         kv_seq_len += 1
 
         # update past_key_value's seem_tokens and kv caches.
-        # [SnapKV]
+        # [CompressKV]
         if use_compresskv:
             past_key_value.update_seen_tokens(self.layer_idx, q_len)
             kv_seq_len = past_key_value.get_seq_length()
@@ -1055,7 +1055,7 @@ def mistral_attention_forward_4_36_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [SnapKV] set attention_mask = None
+        # [CompressKV] set attention_mask = None
         if use_compresskv:
             attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
@@ -1142,7 +1142,7 @@ def mistral_attention_forward_4_39_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
-    # [SnapKV]
+    # [CompressKV]
    use_compresskv = should_use_compresskv(hidden_states)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
@@ -1175,7 +1175,7 @@ def mistral_attention_forward_4_39_original(
         kv_seq_len += 1
 
         # update past_key_value's seem_tokens and kv caches.
-        # [SnapKV]
+        # [CompressKV]
         if use_compresskv:
             past_key_value.update_seen_tokens(self.layer_idx, q_len)
             kv_seq_len = past_key_value.get_seq_length()
@@ -1300,7 +1300,7 @@ def mistral_attention_forward_4_39_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [SnapKV] set attention_mask = None
+        # [CompressKV] set attention_mask = None
         if use_compresskv:
             attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 4b0ad99c..4bf7ae1b 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -47,9 +47,10 @@
 from torch.nn.functional import scaled_dot_product_attention as sdpa
 
 from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
+    should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
-from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
+from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, DynamicCompressCache
 from ipex_llm.utils.common import invalidInputError
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
@@ -117,11 +118,16 @@ def qwen2_model_forward(
         and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
                                   self.config.num_attention_heads//self.config.num_key_value_heads)
     )
+    use_compress_kv = should_use_compresskv(inputs)
 
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+        elif use_compress_kv and not isinstance(past_key_values,
+                                                DynamicCompressCache):
+            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                          DynamicNormalCache):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
         past_key_values_length = past_key_values.get_usable_length(seq_length)
     # ipex-llm changes end
@@ -394,6 +400,9 @@ def qwen2_attention_forward(
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
 
+    # [CompressKV]
+    use_compresskv = should_use_compresskv(hidden_states)
+
     if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
         qkv = self.qkv_proj(hidden_states)
         qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
@@ -427,8 +436,16 @@ def qwen2_attention_forward(
                                                         cos, sin, position_ids)
 
     if past_key_value is not None:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, None)
+        # [CompressKV]
+        if use_compresskv:
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, 256)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, None)
 
     attn_weights = None
     if query_states.device.type == "cpu":
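
Editorial note on the kv.py hunk: instead of leaving the compressed prompt KV as the bare tensors returned by the compression step, the cache now copies them into a buffer pre-allocated with KV_CACHE_ALLOC_BLOCK_LENGTH spare token slots, so the first decode steps can append in place, and it returns a contiguous copy of the full prompt KV when key_states is not tightly packed along the sequence dimension (stride(2) != head_dim). The sketch below is a minimal, self-contained illustration of that pre-allocation pattern; init_kv_cache and append_kv_cache here are simplified stand-ins written for this note (they only mirror the call signatures visible in the diff), not the ipex-llm implementations, and the value 256 for the block length is an assumption.

# Editor's sketch (not part of the patch): pre-allocate KV storage, then append in place.
import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # assumed value of the ipex-llm constant


def init_kv_cache(bsz, num_heads, head_dim, current_length, max_length, dtype, device):
    # Reserve storage for max_length tokens, but expose only the first current_length of them.
    k_storage = torch.empty(bsz, num_heads, max_length, head_dim, dtype=dtype, device=device)
    v_storage = torch.empty(bsz, num_heads, max_length, head_dim, dtype=dtype, device=device)
    view = (bsz, num_heads, current_length, head_dim)
    return (k_storage.as_strided(view, k_storage.stride()),
            v_storage.as_strided(view, v_storage.stride()))


def append_kv_cache(cache_k, cache_v, key_states, value_states):
    # Grow the exposed view and copy the new tokens in place; no reallocation
    # happens while the reserved capacity is not exhausted.
    old_len = cache_k.size(2)
    new_len = old_len + key_states.size(2)
    new_view = (cache_k.size(0), cache_k.size(1), new_len, cache_k.size(3))
    cache_k = cache_k.as_strided(new_view, cache_k.stride())
    cache_v = cache_v.as_strided(new_view, cache_v.stride())
    cache_k[:, :, old_len:new_len, :] = key_states
    cache_v[:, :, old_len:new_len, :] = value_states
    return cache_k, cache_v


# Mirrors the prefill branch added to DynamicCompressCache.update: the compressed
# prompt KV lands in a buffer that still has room for the next decode tokens.
bsz, num_heads, head_dim = 1, 32, 128
key_states_compress = torch.randn(bsz, num_heads, 64, head_dim)
value_states_compress = torch.randn_like(key_states_compress)

k_cache, v_cache = init_kv_cache(
    bsz, num_heads, head_dim,
    0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
    key_states_compress.dtype, key_states_compress.device)
k_cache, v_cache = append_kv_cache(k_cache, v_cache,
                                   key_states_compress, value_states_compress)
print(k_cache.shape)  # torch.Size([1, 32, 64, 128]); capacity for 256 more tokens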
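
Editorial note on the qwen2.py hunks: qwen2_model_forward now picks the cache class in the order FP8-quantized KV, then compressed KV, then the normal dynamic cache, and qwen2_attention_forward passes the extra arguments that DynamicCompressCache.update expects (query states, attention mask, KV group count, config, a room flag, and a trailing 256). A hedged usage sketch follows; the environment variable consulted by should_use_compresskv and the loader arguments are assumptions based on the public ipex-llm interface rather than anything stated in this diff, so verify them against the ipex-llm version you run.

# Editor's sketch (not part of the patch): enabling compressed KV for a Qwen2 model.
import os
os.environ["IPEX_LLM_COMPRESS_KV_CACHE"] = "1"  # assumed opt-in switch read by should_use_compresskv

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM  # ipex-llm drop-in loader

model_path = "Qwen/Qwen2-7B-Instruct"  # any Qwen2 checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             trust_remote_code=True,
                                             use_cache=True).half().to("xpu")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

prompt = "Summarize the following report: ..."  # long prompts are where KV compression helps
inputs = tokenizer(prompt, return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))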