Qwen support compress kv (#11680)

* Qwen support compress kv

* fix style

* fix
Yina Chen 2024-07-30 06:16:42 +03:00 committed by GitHub
parent 9b36877897
commit 670ad887fc
4 changed files with 56 additions and 18 deletions


@@ -259,7 +259,28 @@ class DynamicCompressCache(DynamicCache):
num_key_value_groups=num_key_value_groups)
self.key_cache.append(key_states_compress)
self.value_cache.append(value_states_compress)
return key_states, value_states
k_cache_compressed, v_cache_compressed = init_kv_cache(
bsz, num_heads, head_dim,
0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
key_states.dtype, key_states.device
)
k_cache_compressed, v_cache_compressed = append_kv_cache(
k_cache_compressed, v_cache_compressed,
key_states_compress, value_states_compress)
self.key_cache[layer_idx] = k_cache_compressed
self.value_cache[layer_idx] = v_cache_compressed
if key_states.stride(2) != head_dim:
k_cache, v_cache = init_kv_cache(
bsz, num_heads, head_dim,
0, key_states.size(2),
key_states.dtype, key_states.device
)
k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
return k_cache, v_cache
else:
return key_states, value_states
else:
cache_k = self.key_cache[layer_idx]
cache_v = self.value_cache[layer_idx]
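The new first-token path above copies the compressed keys/values into a cache buffer allocated with KV_CACHE_ALLOC_BLOCK_LENGTH of extra room, so later decode tokens can be appended without reallocating, and, when the incoming key_states are strided (key_states.stride(2) != head_dim), it also copies the full first-token kv into a freshly allocated cache before returning it. Below is a minimal PyTorch sketch of the pre-allocate-then-append idea; the class, the helper names and the 256-token slack are assumptions for illustration, not the ipex-llm init_kv_cache/append_kv_cache implementations.

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # assumed value, for illustration only

class PreallocatedKV:
    """Toy stand-in for the pre-allocate-then-append kv cache pattern."""

    def __init__(self, bsz, num_heads, head_dim, init_len, dtype, device):
        # reserve room for the tokens we have now plus a block of future ones
        capacity = init_len + KV_CACHE_ALLOC_BLOCK_LENGTH
        self.k = torch.empty(bsz, num_heads, capacity, head_dim,
                             dtype=dtype, device=device)
        self.v = torch.empty_like(self.k)
        self.length = 0

    def append(self, key_states, value_states):
        # copy the new tokens into the reserved slack and return views over
        # the filled prefix, analogous to the per-layer cache tensors above
        n = key_states.size(2)
        self.k[:, :, self.length:self.length + n, :] = key_states
        self.v[:, :, self.length:self.length + n, :] = value_states
        self.length += n
        return self.k[:, :, :self.length, :], self.v[:, :, :self.length, :]

# first token: store the compressed kv, leaving room for later decode steps
cache = PreallocatedKV(bsz=1, num_heads=2, head_dim=64, init_len=8,
                       dtype=torch.float16, device="cpu")
k_compress = torch.zeros(1, 2, 8, 64, dtype=torch.float16)
k, v = cache.append(k_compress, torch.zeros_like(k_compress))
print(k.shape)  # torch.Size([1, 2, 8, 64])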


@@ -1289,7 +1289,7 @@ def llama_attention_forward_4_41_original(
# for flash attention
original_dtype = hidden_states.dtype
# [SnapKV]
# [CompressKV]
use_compresskv = should_use_compresskv(hidden_states)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
@@ -1324,7 +1324,7 @@ def llama_attention_forward_4_41_original(
self.rotary_emb.base,)
kv_seq_len += 1
# update past_key_value's seem_tokens and kv caches.
# [SnapKV]
# [CompressKV]
if use_compresskv:
past_key_value.update_seen_tokens(self.layer_idx, q_len)
kv_seq_len = past_key_value.get_seq_length()
@@ -1496,7 +1496,7 @@ def llama_attention_forward_4_41_original(
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import xe_addons
if use_compresskv:
# [SnapKV] set attention_mask = None
# [CompressKV] set attention_mask = None
new_attention_mask = None
attn_output = xe_addons.sdp(query_states, key_states, value_states,
new_attention_mask)
@@ -1833,7 +1833,7 @@ def llama_attention_forward_4_38_original(
# for flash attention
original_dtype = hidden_states.dtype
# [SnapKV]
# [CompressKV]
use_compresskv = should_use_compresskv(hidden_states)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
@@ -1868,7 +1868,7 @@ def llama_attention_forward_4_38_original(
self.rotary_emb.base,)
kv_seq_len += 1
# update past_key_value's seem_tokens and kv caches.
# [SnapKV]
# [CompressKV]
if use_compresskv:
past_key_value.update_seen_tokens(self.layer_idx, q_len)
kv_seq_len = past_key_value.get_seq_length()
@@ -2039,7 +2039,7 @@ def llama_attention_forward_4_38_original(
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import xe_addons
if use_compresskv:
# [SnapKV] set attention_mask = None
# [CompressKV] set attention_mask = None
new_attention_mask = None
attn_output = xe_addons.sdp(query_states, key_states, value_states,
new_attention_mask)
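The llama hunks above (and the mistral hunks below) only rename the [SnapKV] markers to [CompressKV]; the surrounding logic that passes attention_mask = None to the fused SDP call on the compressed path is unchanged context. As a toy illustration of why a mask sized for the full sequence cannot be paired with a compressed kv cache (all shapes here are made up):

import torch

bsz, n_heads, q_len, head_dim = 1, 32, 1, 128
kv_len_full, kv_len_compressed = 1024, 256   # made-up lengths

attention_mask = torch.zeros(bsz, 1, q_len, kv_len_full)              # sized for the full history
key_states = torch.zeros(bsz, n_heads, kv_len_compressed, head_dim)   # compressed cache

# the pre-built mask no longer lines up with the cached kv length, so the
# compressed path calls the fused sdp kernel with attention_mask=None
assert attention_mask.size(-1) != key_states.size(2)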


@@ -897,7 +897,7 @@ def mistral_attention_forward_4_36_original(
# for flash attention
original_dtype = hidden_states.dtype
# [SnapKV]
# [CompressKV]
use_compresskv = should_use_compresskv(hidden_states)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
@@ -930,7 +930,7 @@ def mistral_attention_forward_4_36_original(
kv_seq_len += 1
# update past_key_value's seem_tokens and kv caches.
# [SnapKV]
# [CompressKV]
if use_compresskv:
past_key_value.update_seen_tokens(self.layer_idx, q_len)
kv_seq_len = past_key_value.get_seq_length()
@@ -1055,7 +1055,7 @@ def mistral_attention_forward_4_36_original(
elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
# new fp16 sdp doesn't require repeat_kv
import xe_addons
# [SnapKV] set attention_mask = None
# [CompressKV] set attention_mask = None
if use_compresskv:
attention_mask = None
attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
@@ -1142,7 +1142,7 @@ def mistral_attention_forward_4_39_original(
# for flash attention
original_dtype = hidden_states.dtype
# [SnapKV]
# [CompressKV]
use_compresskv = should_use_compresskv(hidden_states)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
@@ -1175,7 +1175,7 @@ def mistral_attention_forward_4_39_original(
kv_seq_len += 1
# update past_key_value's seem_tokens and kv caches.
# [SnapKV]
# [CompressKV]
if use_compresskv:
past_key_value.update_seen_tokens(self.layer_idx, q_len)
kv_seq_len = past_key_value.get_seq_length()
@@ -1300,7 +1300,7 @@ def mistral_attention_forward_4_39_original(
elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
# new fp16 sdp doesn't require repeat_kv
import xe_addons
# [SnapKV] set attention_mask = None
# [CompressKV] set attention_mask = None
if use_compresskv:
attention_mask = None
attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)


@@ -47,9 +47,10 @@ from torch.nn.functional import scaled_dot_product_attention as sdpa
from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
from ipex_llm.transformers.models.utils import should_use_fuse_rope
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
should_use_compresskv, is_enough_kv_cache_room_4_36
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, DynamicCompressCache
from ipex_llm.utils.common import invalidInputError
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
@@ -117,11 +118,16 @@ def qwen2_model_forward(
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
self.config.num_attention_heads//self.config.num_key_value_heads)
)
use_compress_kv = should_use_compresskv(inputs)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
elif use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
DynamicNormalCache):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
# ipex-llm changes end
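The qwen2_model_forward change wires compressed kv into the cache selection: quantized FP8 kv keeps priority, compressed kv is chosen next, and DynamicNormalCache is used only when neither applies. A minimal sketch of that ordering, with the real predicates (which live in ipex_llm.transformers.models.utils) reduced to plain booleans:

def select_cache_kind(use_quantize_kv: bool, use_compress_kv: bool) -> str:
    # mirrors the priority in the hunk above: fp8 kv > compressed kv > normal
    if use_quantize_kv:
        return "DynamicFp8Cache"
    elif use_compress_kv:
        return "DynamicCompressCache"
    return "DynamicNormalCache"

assert select_cache_kind(True, True) == "DynamicFp8Cache"
assert select_cache_kind(False, True) == "DynamicCompressCache"
assert select_cache_kind(False, False) == "DynamicNormalCache"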
@@ -394,6 +400,9 @@ def qwen2_attention_forward(
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
# [CompressKV]
use_compresskv = should_use_compresskv(hidden_states)
if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
@@ -427,8 +436,16 @@ def qwen2_attention_forward(
cos, sin, position_ids)
if past_key_value is not None:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)
# [CompressKV]
if use_compresskv:
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
query_states, attention_mask, self.num_key_value_groups,
self.config, enough_kv_room, 256)
else:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)
attn_weights = None
if query_states.device.type == "cpu":
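Finally, qwen2_attention_forward now branches on use_compresskv when updating the cache: the compressed cache's update() receives the queries, attention mask, kv grouping, config and headroom flag so it can decide which entries to keep, while the normal cache keeps the plain append. The stub below only mirrors the call shapes from the hunk above; it is not the ipex-llm DynamicCompressCache API.

class _CacheStub:
    """Records update() arguments; for illustration only."""

    def update(self, *args):
        return args

def update_kv(past_key_value, use_compresskv, key_states, value_states, layer_idx,
              query_states, attention_mask, num_key_value_groups, config,
              enough_kv_room):
    if use_compresskv:
        # compressed path: the extra arguments drive the compression decision;
        # the trailing 256 matches the constant passed in the hunk above
        return past_key_value.update(key_states, value_states, layer_idx,
                                     query_states, attention_mask,
                                     num_key_value_groups, config,
                                     enough_kv_room, 256)
    # normal path: plain append with no cache kwargs
    return past_key_value.update(key_states, value_states, layer_idx, None)

args = update_kv(_CacheStub(), True, "k", "v", 0, "q", None, 4, None, True)
assert len(args) == 9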