Qwen support compress kv (#11680)

* Qwen support compress kv
* fix style
* fix
parent 9b36877897
commit 670ad887fc

4 changed files with 56 additions and 18 deletions
@@ -259,7 +259,28 @@ class DynamicCompressCache(DynamicCache):
                 num_key_value_groups=num_key_value_groups)
             self.key_cache.append(key_states_compress)
             self.value_cache.append(value_states_compress)
-            return key_states, value_states
+
+            k_cache_compressed, v_cache_compressed = init_kv_cache(
+                bsz, num_heads, head_dim,
+                0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                key_states.dtype, key_states.device
+            )
+            k_cache_compressed, v_cache_compressed = append_kv_cache(
+                k_cache_compressed, v_cache_compressed,
+                key_states_compress, value_states_compress)
+            self.key_cache[layer_idx] = k_cache_compressed
+            self.value_cache[layer_idx] = v_cache_compressed
+
+            if key_states.stride(2) != head_dim:
+                k_cache, v_cache = init_kv_cache(
+                    bsz, num_heads, head_dim,
+                    0, key_states.size(2),
+                    key_states.dtype, key_states.device
+                )
+                k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
+                return k_cache, v_cache
+            else:
+                return key_states, value_states
         else:
             cache_k = self.key_cache[layer_idx]
             cache_v = self.value_cache[layer_idx]
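The hunk above replaces the plain append of the compressed prefill tensors with a pre-allocated cache: init_kv_cache reserves room for the compressed length plus KV_CACHE_ALLOC_BLOCK_LENGTH extra slots, and append_kv_cache copies the compressed keys/values into that reserved storage, so later decode tokens can be appended without reallocating on every step. Below is a minimal, self-contained sketch of that pre-allocate-then-append pattern; the helper names mirror the diff, but the bodies are illustrative stand-ins, not ipex-llm's actual implementations.

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # illustrative block size (the qwen2 hunk below passes 256)

def init_kv_cache(bsz, num_heads, head_dim, current_len, max_len, dtype, device):
    # Reserve storage for up to `max_len` tokens, but expose a view of only `current_len`.
    k_storage = torch.empty(bsz, num_heads, max_len, head_dim, dtype=dtype, device=device)
    v_storage = torch.empty(bsz, num_heads, max_len, head_dim, dtype=dtype, device=device)
    view = (bsz, num_heads, current_len, head_dim)
    return (k_storage.as_strided(view, k_storage.stride(), storage_offset=0),
            v_storage.as_strided(view, v_storage.stride(), storage_offset=0))

def append_kv_cache(cache_k, cache_v, key_states, value_states):
    # Widen the view over the reserved storage and copy the new tokens in place.
    # Assumes the underlying allocation still has room (true while within the alloc block).
    old_len, add_len = cache_k.size(2), key_states.size(2)
    new_size = (cache_k.size(0), cache_k.size(1), old_len + add_len, cache_k.size(3))
    new_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)
    new_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
    new_k[:, :, old_len:, :] = key_states
    new_v[:, :, old_len:, :] = value_states
    return new_k, new_v

# Example: cache the compressed prefill KV, then append one decode step without reallocating.
k_compress = torch.randn(1, 32, 512, 128)
v_compress = torch.randn_like(k_compress)
k_cache, v_cache = init_kv_cache(1, 32, 128, 0,
                                 k_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                 k_compress.dtype, k_compress.device)
k_cache, v_cache = append_kv_cache(k_cache, v_cache, k_compress, v_compress)
k_cache, v_cache = append_kv_cache(k_cache, v_cache,
                                   torch.randn(1, 32, 1, 128), torch.randn(1, 32, 1, 128))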
@@ -1289,7 +1289,7 @@ def llama_attention_forward_4_41_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
-    # [SnapKV]
+    # [CompressKV]
     use_compresskv = should_use_compresskv(hidden_states)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)

@@ -1324,7 +1324,7 @@ def llama_attention_forward_4_41_original(
                                                        self.rotary_emb.base,)
         kv_seq_len += 1
         # update past_key_value's seem_tokens and kv caches.
-        # [SnapKV]
+        # [CompressKV]
         if use_compresskv:
             past_key_value.update_seen_tokens(self.layer_idx, q_len)
             kv_seq_len = past_key_value.get_seq_length()

@@ -1496,7 +1496,7 @@ def llama_attention_forward_4_41_original(
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
         if use_compresskv:
-            # [SnapKV] set attention_mask = None
+            # [CompressKV] set attention_mask = None
             new_attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)

@@ -1833,7 +1833,7 @@ def llama_attention_forward_4_38_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
-    # [SnapKV]
+    # [CompressKV]
     use_compresskv = should_use_compresskv(hidden_states)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)

@@ -1868,7 +1868,7 @@ def llama_attention_forward_4_38_original(
                                                        self.rotary_emb.base,)
         kv_seq_len += 1
         # update past_key_value's seem_tokens and kv caches.
-        # [SnapKV]
+        # [CompressKV]
         if use_compresskv:
             past_key_value.update_seen_tokens(self.layer_idx, q_len)
             kv_seq_len = past_key_value.get_seq_length()

@@ -2039,7 +2039,7 @@ def llama_attention_forward_4_38_original(
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
         if use_compresskv:
-            # [SnapKV] set attention_mask = None
+            # [CompressKV] set attention_mask = None
             new_attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
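In both Llama paths the decoding fast path now asks the compressed cache itself for the effective KV length: update_seen_tokens advances the count of tokens the model has actually processed, while get_seq_length reports how many entries are really stored, which is smaller once the prompt has been compressed. The following is a minimal sketch of that bookkeeping with assumed internals; the real class is DynamicCompressCache in ipex_llm.transformers.kv, and this toy is only meant to show why the two numbers differ.

import torch

class ToyCompressCache:
    # Illustrative stand-in for the bookkeeping only, not the real DynamicCompressCache.
    def __init__(self):
        self.key_cache, self.value_cache = [], []
        self.seen_tokens = 0  # tokens the model has processed (uncompressed count)

    def update_seen_tokens(self, layer_idx, q_len):
        if layer_idx == 0:  # count each forward pass once
            self.seen_tokens += q_len

    def get_seq_length(self, layer_idx=0):
        # Entries actually held for this layer; after prompt compression this is
        # smaller than seen_tokens, and it is what the attention code uses as kv_seq_len.
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

cache = ToyCompressCache()
cache.key_cache.append(torch.zeros(1, 32, 512, 128))  # compressed prompt: 512 kept entries
cache.update_seen_tokens(layer_idx=0, q_len=4096)     # although 4096 tokens were actually seen
assert cache.get_seq_length() == 512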
@@ -897,7 +897,7 @@ def mistral_attention_forward_4_36_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
-    # [SnapKV]
+    # [CompressKV]
     use_compresskv = should_use_compresskv(hidden_states)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)

@@ -930,7 +930,7 @@ def mistral_attention_forward_4_36_original(
         kv_seq_len += 1
 
         # update past_key_value's seem_tokens and kv caches.
-        # [SnapKV]
+        # [CompressKV]
         if use_compresskv:
             past_key_value.update_seen_tokens(self.layer_idx, q_len)
             kv_seq_len = past_key_value.get_seq_length()

@@ -1055,7 +1055,7 @@ def mistral_attention_forward_4_36_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [SnapKV] set attention_mask = None
+        # [CompressKV] set attention_mask = None
         if use_compresskv:
             attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)

@@ -1142,7 +1142,7 @@ def mistral_attention_forward_4_39_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
-    # [SnapKV]
+    # [CompressKV]
     use_compresskv = should_use_compresskv(hidden_states)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)

@@ -1175,7 +1175,7 @@ def mistral_attention_forward_4_39_original(
         kv_seq_len += 1
 
         # update past_key_value's seem_tokens and kv caches.
-        # [SnapKV]
+        # [CompressKV]
         if use_compresskv:
             past_key_value.update_seen_tokens(self.layer_idx, q_len)
             kv_seq_len = past_key_value.get_seq_length()

@@ -1300,7 +1300,7 @@ def mistral_attention_forward_4_39_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [SnapKV] set attention_mask = None
+        # [CompressKV] set attention_mask = None
         if use_compresskv:
             attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
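In the fused-SDP branches above, the attention mask is dropped whenever CompressKV is active, most likely because the mask was built for the full, uncompressed sequence length and no longer matches the shortened key length held in the compressed cache. The snippet below illustrates that shape mismatch with PyTorch's stock scaled_dot_product_attention; xe_addons.sdp is an Intel-GPU-specific kernel, so this is only an analogy, and the sizes are made up.

import torch
import torch.nn.functional as F

bsz, num_heads, head_dim = 1, 8, 64
full_len, kept_len = 1024, 256                       # illustrative: full prompt vs compressed cache

q = torch.randn(bsz, num_heads, 1, head_dim)         # one decode-step query
k = torch.randn(bsz, num_heads, kept_len, head_dim)  # compressed KV cache
v = torch.randn_like(k)

mask_for_full_prompt = torch.zeros(bsz, 1, 1, full_len)
# F.scaled_dot_product_attention(q, k, v, attn_mask=mask_for_full_prompt)
#   -> would fail: the mask's last dim (1024) no longer matches the 256 kept keys.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None)  # mask-free, as in the diff
print(out.shape)  # torch.Size([1, 8, 1, 64])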
@@ -47,9 +47,10 @@ from torch.nn.functional import scaled_dot_product_attention as sdpa
 
 from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
+    should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
-from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
+from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, DynamicCompressCache
 from ipex_llm.utils.common import invalidInputError
 
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
@@ -117,11 +118,16 @@ def qwen2_model_forward(
                        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
                                                  self.config.num_attention_heads//self.config.num_key_value_heads)
                        )
+    use_compress_kv = should_use_compresskv(inputs)
+
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+        elif use_compress_kv and not isinstance(past_key_values,
+                                                DynamicCompressCache):
+            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                          DynamicNormalCache):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
     past_key_values_length = past_key_values.get_usable_length(seq_length)
     # ipex-llm changes end
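The model-forward change above wires the new cache into the existing dispatch: the quantized FP8 cache is checked first, the compressed cache is used when should_use_compresskv says so, and everything else falls back to the normal cache. A condensed, self-contained restatement of that priority follows; the classes are empty stubs so the snippet runs without ipex-llm installed, standing in for the DynamicFp8Cache / DynamicCompressCache / DynamicNormalCache imported above.

class Fp8CacheStub: ...
class CompressCacheStub: ...
class NormalCacheStub: ...

def pick_cache_type(use_quantize_kv: bool, use_compress_kv: bool) -> type:
    # Mirrors the if / elif / if chain in qwen2_model_forward above.
    if use_quantize_kv:
        return Fp8CacheStub
    elif use_compress_kv:
        return CompressCacheStub
    return NormalCacheStub

assert pick_cache_type(True, True) is Fp8CacheStub       # quantized KV is checked first
assert pick_cache_type(False, True) is CompressCacheStub
assert pick_cache_type(False, False) is NormalCacheStub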
@@ -394,6 +400,9 @@ def qwen2_attention_forward(
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
 
+    # [CompressKV]
+    use_compresskv = should_use_compresskv(hidden_states)
+
     if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
         qkv = self.qkv_proj(hidden_states)
         qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
@@ -427,8 +436,16 @@ def qwen2_attention_forward(
                                                          cos, sin, position_ids)
 
     if past_key_value is not None:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                          self.layer_idx, None)
+        # [CompressKV]
+        if use_compresskv:
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, 256)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                              self.layer_idx, None)
 
     attn_weights = None
     if query_states.device.type == "cpu":
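When CompressKV is active, past_key_value.update also receives the query states, attention mask, number of key-value groups, model config, and the allocation block length (the trailing 256), because the compressed cache needs them to score which prompt entries to keep at prefill time. The sketch below is a rough, self-contained illustration of the SnapKV-style selection such a compression step performs; the function name, parameters, and pooling details are assumptions for illustration (GQA group handling is omitted), not ipex-llm's actual compress step.

import math
import torch

def compress_kv_sketch(key_states, value_states, query_states, window=32, capacity=512):
    # key/value/query_states: (bsz, num_heads, seq_len, head_dim)
    bsz, num_heads, seq_len, head_dim = key_states.shape
    if seq_len <= capacity:
        return key_states, value_states                 # short prompt: keep everything

    # Let the most recent `window` queries vote on which earlier KV entries matter.
    q_recent = query_states[:, :, -window:, :]
    scores = q_recent @ key_states.transpose(-2, -1) / math.sqrt(head_dim)
    votes = scores.softmax(dim=-1).sum(dim=-2)          # (bsz, num_heads, seq_len)
    votes[..., -window:] = float("inf")                 # always keep the recent window

    keep = votes.topk(capacity, dim=-1).indices.sort(dim=-1).values
    keep = keep.unsqueeze(-1).expand(-1, -1, -1, head_dim)
    return key_states.gather(2, keep), value_states.gather(2, keep)

# Example: a 4096-token prompt is reduced to 512 kept entries per head.
k = torch.randn(1, 32, 4096, 128)
v = torch.randn_like(k)
q = torch.randn(1, 32, 4096, 128)
k_small, v_small = compress_kv_sketch(k, v, q)
print(k_small.shape)  # torch.Size([1, 32, 512, 128])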