Support compress kv (#11642)

* mistral snapkv

* update

* mtl update

* update

* update

* update

* add comments

* style fix

* fix style

* support llama

* llama use compress kv

* support mistral 4.40

* fix style

* support diff transformers versions

* move snapkv util to kv

* fix style

* meet comments & small fix

* revert all in one

* fix indent

---------

Co-authored-by: leonardozcm <leonardo1997zcm@gmail.com>
Yina Chen 2024-07-26 11:02:00 +03:00 committed by GitHub
parent 6bcdc6cc8f
commit fc7f8feb83
5 changed files with 422 additions and 152 deletions


@@ -1443,14 +1443,14 @@ def _optimize_post(model, lightweight_bmm=False):
if version.parse(trans_version) >= version.parse("4.36.0"):
from ipex_llm.transformers.models.mistral import mistral_model_forward_4_36
if version.parse(trans_version) >= version.parse("4.39.0"):
from ipex_llm.transformers.models.mistral import mistral_attention_forward_4_39
from ipex_llm.transformers.models.mistral import \
mistral_attention_forward_4_39
convert_forward(model,
module.MistralAttention,
mistral_attention_forward_4_39
)
else:
from ipex_llm.transformers.models.mistral import mistral_attention_forward_4_36
convert_forward(model,
module.MistralAttention,
mistral_attention_forward_4_36
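
The hunk above only re-wraps the 4.39 import, but it is the hook point: convert_forward swaps the stock MistralAttention.forward for the ipex-llm implementation that knows about the compressed cache. A minimal sketch of that forward-replacement pattern, assuming convert_forward simply rebinds forward on every instance of the target class (replace_forward below is a hypothetical stand-in, not the library's actual code):

    import types
    import torch.nn as nn

    def replace_forward(model: nn.Module, target_class: type, new_forward) -> None:
        # Walk the module tree and rebind `forward` on every matching submodule.
        for module in model.modules():
            if isinstance(module, target_class):
                module.forward = types.MethodType(new_forward, module)

    # Usage mirrors the calls above, e.g.:
    # replace_forward(model, module.MistralAttention, mistral_attention_forward_4_39)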


@@ -16,13 +16,17 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
from .models.utils import (
init_fp8_kv_cache, append_fp8_kv_cache,
init_kv_cache, append_kv_cache
init_kv_cache, append_kv_cache, extend_kv_cache
)
from typing import Optional, Dict, Tuple, Any
from transformers.cache_utils import DynamicCache
from ipex_llm.utils.common.log4Error import invalidInputError
class DynamicFp8Cache(DynamicCache):
@@ -116,3 +120,178 @@ class DynamicNormalCache(DynamicCache):
self.value_cache[layer_idx] = v_cache
return self.key_cache[layer_idx], self.value_cache[layer_idx]
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
# This function is adapted from
# https://github.com/FasterDecoding/SnapKV/blob/main/snapkv/monkeypatch/snapkv_utils.py
def compress_kv(attn_config, key_states, query_states, value_states, attention_mask,
num_key_value_groups):
# check if prefix phase
invalidInputError(key_states.shape[-2] == query_states.shape[-2], "kv shape mismatch.")
if not hasattr(attn_config, 'window_size'):
attn_config.window_size = 32
if not hasattr(attn_config, 'max_capacity_prompt'):
attn_config.max_capacity_prompt = 512
if not hasattr(attn_config, 'kernel_size'):
attn_config.kernel_size = 5
if not hasattr(attn_config, 'pooling'):
attn_config.pooling = 'avgpool'
bsz, num_heads, q_len, head_dim = query_states.shape
if q_len < attn_config.max_capacity_prompt:
return key_states, value_states
else:
key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)
attn_weights = torch.matmul(query_states[..., -attn_config.window_size:, :],
key_states_expand.transpose(2, 3)) / math.sqrt(head_dim)
mask = torch.full((attn_config.window_size, attn_config.window_size),
torch.finfo(attn_weights.dtype).min,
device=attn_weights.device)
mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(attn_weights.device)
attention_mask = mask[None, None, :, :]
attn_weights[:, :, -attn_config.window_size:, -attn_config.window_size:] += attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_weights_sum = attn_weights[:, :, -attn_config.window_size:,
:-attn_config.window_size].sum(dim=-2)
if attn_config.pooling == 'avgpool':
if num_key_value_groups > 1:
attn_cache = F.avg_pool2d(attn_weights_sum, kernel_size=(num_key_value_groups,
attn_config.kernel_size),
padding=(0, attn_config.kernel_size//2),
stride=(num_key_value_groups, 1))
else:
attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
padding=attn_config.kernel_size//2, stride=1)
elif attn_config.pooling == 'maxpool':
if num_key_value_groups > 1:
attn_cache = F.max_pool2d(attn_weights_sum,
kernel_size=(num_key_value_groups,
attn_config.kernel_size),
padding=(0, attn_config.kernel_size//2),
stride=(num_key_value_groups, 1))
else:
attn_cache = F.max_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
padding=attn_config.kernel_size//2, stride=1)
else:
invalidInputError(False, 'Pooling method not supported')
indices = attn_cache.topk(attn_config.max_capacity_prompt - attn_config.window_size,
dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
k_past_compress = key_states[:, :, :-attn_config.window_size, :].gather(dim=2,
index=indices)
v_past_compress = value_states[:, :, :-attn_config.window_size, :].gather(dim=2,
index=indices)
k_cur = key_states[:, :, -attn_config.window_size:, :]
v_cur = value_states[:, :, -attn_config.window_size:, :]
key_states = torch.cat([k_past_compress, k_cur], dim=2)
value_states = torch.cat([v_past_compress, v_cur], dim=2)
return key_states, value_states
class DynamicCompressCache(DynamicCache):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.real_kv_len = 0
def update_seen_tokens(self, layer_idx, q_len):
if layer_idx == 0:
if hasattr(self, "_seen_tokens"):
# 4.39 uses `_seen_tokens`
self._seen_tokens += q_len
else:
# 4.37 uses `seen_tokens`
self.seen_tokens += q_len
self.real_kv_len += q_len
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
query_states: torch.Tensor,
attention_mask: torch.Tensor,
num_key_value_groups: int,
attn_config: Dict[str, Any],
enough_kv_room: bool,
KV_CACHE_ALLOC_BLOCK_LENGTH: int,
cache_kwargs: Optional[Dict[str, Any]]=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
bsz, num_heads, seq_len, head_dim = key_states.shape
if layer_idx == 0:
if hasattr(self, "_seen_tokens"):
# 4.39 uses `_seen_tokens`
self._seen_tokens += seq_len
else:
# 4.37 uses `seen_tokens`
self.seen_tokens += seq_len
self.real_kv_len += seq_len
# Update the cache
if len(self.key_cache) <= layer_idx:
# First token, compress kv cache
key_states_compress, value_states_compress = compress_kv(
attn_config=attn_config,
key_states=key_states,
query_states=query_states,
value_states=value_states,
attention_mask=attention_mask,
num_key_value_groups=num_key_value_groups)
self.key_cache.append(key_states_compress)
self.value_cache.append(value_states_compress)
return key_states, value_states
else:
cache_k = self.key_cache[layer_idx]
cache_v = self.value_cache[layer_idx]
if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
num_heads, # Support GQA
head_dim,
cache_k.size(2),
cache_k.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=query_states.device)
new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v
key_states, value_states = append_kv_cache(cache_k,
cache_v,
key_states,
value_states)
# update past_key_value
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer
index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
return self.real_kv_len
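
In short, compress_kv implements the SnapKV selection: on the prefill pass it keeps the last window_size tokens untouched and, from the earlier part of the prompt, only the max_capacity_prompt - window_size positions that receive the highest pooled attention from that recent window. A minimal sketch of calling it directly with dummy tensors; the shapes, group count and config values are illustrative assumptions, and the import assumes an ipex-llm build that contains this change:

    import torch
    from types import SimpleNamespace
    from ipex_llm.transformers.kv import compress_kv

    bsz, n_kv_heads, n_heads, q_len, head_dim = 1, 8, 32, 1024, 128
    cfg = SimpleNamespace(window_size=32, max_capacity_prompt=512,
                          kernel_size=5, pooling='avgpool')
    k = torch.randn(bsz, n_kv_heads, q_len, head_dim)
    v = torch.randn(bsz, n_kv_heads, q_len, head_dim)
    q = torch.randn(bsz, n_heads, q_len, head_dim)

    # q_len > max_capacity_prompt, so the prompt KV is compressed down to
    # 480 top-scoring past tokens plus the 32 most recent ones.
    ck, cv = compress_kv(cfg, k, q, v, attention_mask=None,
                         num_key_value_groups=n_heads // n_kv_heads)
    print(ck.shape)  # torch.Size([1, 8, 512, 128])

Note that DynamicCompressCache.update stores the compressed tensors but still returns the full-length key_states/value_states on the first call, so prefill attention runs over the whole prompt; only later decode steps see the shortened cache.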


@@ -42,7 +42,7 @@ import torch.nn.functional as F
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import SILU
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
restore_fp8_kv_cache, use_quantize_kv_cache
restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -113,12 +113,18 @@ def llama_model_forward_4_36(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
from ipex_llm.transformers.kv import DynamicFp8Cache
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
if use_cache:
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif should_use_compresskv(input):
# if use quantize kv, compress kv will be ignored now
if not isinstance(past_key_values, DynamicCompressCache):
past_key_values = DynamicCompressCache.from_legacy_cache(
past_key_values)
return llama_model_forward_4_36_internal(
self=self,
input_ids=input_ids,
@@ -146,12 +152,18 @@ def llama_model_forward_4_38(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
from ipex_llm.transformers.kv import DynamicFp8Cache
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
if use_cache:
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif should_use_compresskv(input):
# if use quantize kv, compress kv will be ignored now
if not isinstance(past_key_values, DynamicCompressCache):
past_key_values = DynamicCompressCache.from_legacy_cache(
past_key_values)
return llama_model_forward_4_38_internal(
self=self,
input_ids=input_ids,
@@ -180,12 +192,18 @@ def llama_model_forward_4_41(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
from ipex_llm.transformers.kv import DynamicFp8Cache
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
if use_cache:
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif should_use_compresskv(input):
# if use quantize kv, compress kv will be ignored now
if not isinstance(past_key_values, DynamicCompressCache):
past_key_values = DynamicCompressCache.from_legacy_cache(
past_key_values)
return llama_model_forward_4_41_internal(
self=self,
input_ids=input_ids,
@@ -1267,6 +1285,9 @@ def llama_attention_forward_4_41_original(
# for flash attention
original_dtype = hidden_states.dtype
# [SnapKV]
use_compresskv = should_use_compresskv(hidden_states)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
no_tp = not self.config.pretraining_tp > 1
@@ -1299,7 +1320,11 @@ def llama_attention_forward_4_41_original(
self.rotary_emb.base,)
kv_seq_len += 1
# update past_key_value's seen_tokens and kv caches.
if self.layer_idx == 0:
# [SnapKV]
if use_compresskv:
past_key_value.update_seen_tokens(self.layer_idx, q_len)
kv_seq_len = past_key_value.get_seq_length()
elif self.layer_idx == 0:
past_key_value._seen_tokens = kv_seq_len
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
@@ -1404,46 +1429,51 @@ def llama_attention_forward_4_41_original(
cos, sin, position_ids, "llama")
if past_key_value is not None:
# update the number of seen tokens
if self.layer_idx == 0:
past_key_value._seen_tokens += key_states.shape[-2]
# reuse k, v, self_attention
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
if len(past_key_value.key_cache) <= self.layer_idx:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
if use_compresskv:
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
query_states, attention_mask, self.num_key_value_groups,
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
else:
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]
# update the number of seen tokens
if self.layer_idx == 0:
past_key_value._seen_tokens += key_states.shape[-2]
if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
# reuse k, v, self_attention
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
if len(past_key_value.key_cache) <= self.layer_idx:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
else:
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]
new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v
if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
key_states, value_states = append_kv_cache(cache_k,
cache_v,
key_states,
value_states)
new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v
# update past_key_value
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
key_states, value_states = append_kv_cache(cache_k,
cache_v,
key_states,
value_states)
# update past_key_value
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
if cache_position is not None:
new_attention_mask = attention_mask[:, :, :, 0:kv_seq_len]
else:
new_attention_mask = attention_mask
@@ -1461,6 +1491,9 @@ def llama_attention_forward_4_41_original(
elif not self.training and not hidden_states.requires_grad and \
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import xe_addons
if use_compresskv:
# [SnapKV] set attention_mask = None
new_attention_mask = None
attn_output = xe_addons.sdp(query_states, key_states, value_states,
new_attention_mask)
attn_output = attn_output.view(query_states.shape)
@@ -1791,6 +1824,9 @@ def llama_attention_forward_4_38_original(
# for flash attention
original_dtype = hidden_states.dtype
# [SnapKV]
use_compresskv = should_use_compresskv(hidden_states)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
no_tp = not self.config.pretraining_tp > 1
@@ -1823,11 +1859,14 @@ def llama_attention_forward_4_38_original(
self.rotary_emb.base,)
kv_seq_len += 1
# update past_key_value's seen_tokens and kv caches.
if self.layer_idx == 0:
# [SnapKV]
if use_compresskv:
past_key_value.update_seen_tokens(self.layer_idx, q_len)
kv_seq_len = past_key_value.get_seq_length()
elif self.layer_idx == 0:
past_key_value.seen_tokens = kv_seq_len
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
else:
if self.config.pretraining_tp > 1:
key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
@@ -1928,42 +1967,48 @@ def llama_attention_forward_4_38_original(
cos, sin, position_ids, "llama")
if past_key_value is not None:
# update the number of seen tokens
if self.layer_idx == 0:
past_key_value.seen_tokens += key_states.shape[-2]
# reuse k, v, self_attention
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
if len(past_key_value.key_cache) <= self.layer_idx:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
if use_compresskv:
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
query_states, attention_mask, self.num_key_value_groups,
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
else:
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]
# update the number of seen tokens
if self.layer_idx == 0:
past_key_value.seen_tokens += key_states.shape[-2]
if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
# reuse k, v, self_attention
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
if len(past_key_value.key_cache) <= self.layer_idx:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
else:
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]
new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v
if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
key_states, value_states = append_kv_cache(cache_k,
cache_v,
key_states,
value_states)
new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v
# update past_key_value
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
key_states, value_states = append_kv_cache(cache_k,
cache_v,
key_states,
value_states)
# update past_key_value
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
if cache_position is not None:
new_attention_mask = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
@@ -1984,6 +2029,9 @@ def llama_attention_forward_4_38_original(
elif not self.training and not hidden_states.requires_grad and \
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import xe_addons
if use_compresskv:
# [SnapKV] set attention_mask = None
new_attention_mask = None
attn_output = xe_addons.sdp(query_states, key_states, value_states,
new_attention_mask)
attn_output = attn_output.view(query_states.shape)
@@ -2515,11 +2563,11 @@ def llama_model_forward_4_41_internal(
all_hidden_states += (hidden_states,)
next_cache = None
from ipex_llm.transformers.kv import DynamicFp8Cache
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if not isinstance(next_decoder_cache, DynamicFp8Cache)
if not isinstance(next_decoder_cache, (DynamicFp8Cache, DynamicCompressCache))
else next_decoder_cache
)
@@ -2645,11 +2693,11 @@ def llama_model_forward_4_38_internal(
all_hidden_states += (hidden_states,)
next_cache = None
from ipex_llm.transformers.kv import DynamicFp8Cache
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if not isinstance(next_decoder_cache, DynamicFp8Cache)
if not isinstance(next_decoder_cache, (DynamicFp8Cache, DynamicCompressCache))
else next_decoder_cache
)
if not return_dict:
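
Taken together, the compressed cache is opt-in: when the IPEX_LLM_COMPRESS_KV_CACHE flag is set and the model runs on an XPU device (and quantized kv-cache is not in effect), the llama forwards above wrap past_key_values in a DynamicCompressCache automatically. A hypothetical end-to-end sketch; the model id, 4-bit loading and dtype are assumptions, and it needs an Intel GPU with ipex-llm installed:

    import os
    import torch

    os.environ["IPEX_LLM_COMPRESS_KV_CACHE"] = "1"  # must be set before the forward pass

    from ipex_llm.transformers import AutoModelForCausalLM
    from ipex_llm.transformers.kv import DynamicCompressCache
    from transformers import AutoTokenizer

    model_path = "meta-llama/Llama-2-7b-chat-hf"  # hypothetical choice
    model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True).half().to("xpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    inputs = tokenizer("A very long prompt ...", return_tensors="pt").to("xpu")
    with torch.inference_mode():
        out = model(**inputs, use_cache=True)
    # Expect DynamicCompressCache here when quantized kv-cache is not also enabled.
    print(type(out.past_key_values))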


@@ -46,7 +46,7 @@ from transformers.models.mistral.modeling_mistral import MistralModel
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
restore_fp8_kv_cache, use_quantize_kv_cache
restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
apply_rotary_pos_emb_no_cache_xpu
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
@@ -202,11 +202,17 @@ def mistral_model_forward_4_36(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
from ipex_llm.transformers.kv import DynamicFp8Cache
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
use_cache = use_cache if use_cache is not None else self.config.use_cache
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
if use_cache:
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif should_use_compresskv(input_ids):
# if use quantize kv, compress kv will be ignored now
if not isinstance(past_key_values, DynamicCompressCache):
past_key_values = DynamicCompressCache.from_legacy_cache(
past_key_values)
return MistralModel.forward(
self=self,
input_ids=input_ids,
@@ -890,6 +896,9 @@ def mistral_attention_forward_4_36_original(
# for flash attention
original_dtype = hidden_states.dtype
# [SnapKV]
use_compresskv = should_use_compresskv(hidden_states)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
decoding_fast_path = use_decoding_fast_path(self.q_proj,
@@ -920,7 +929,11 @@ def mistral_attention_forward_4_36_original(
kv_seq_len += 1
# update past_key_value's seen_tokens and kv caches.
if self.layer_idx == 0:
# [SnapKV]
if use_compresskv:
past_key_value.update_seen_tokens(self.layer_idx, q_len)
kv_seq_len = past_key_value.get_seq_length()
elif self.layer_idx == 0:
past_key_value.seen_tokens = kv_seq_len
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
@@ -975,40 +988,46 @@ def mistral_attention_forward_4_36_original(
cos, sin, position_ids, "mistral")
if past_key_value is not None:
# update the number of seen tokens
if self.layer_idx == 0:
past_key_value.seen_tokens += key_states.shape[-2]
# reuse k, v, self_attention
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
if len(past_key_value.key_cache) <= self.layer_idx:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
if use_compresskv:
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
query_states, attention_mask, self.num_key_value_groups,
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
else:
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]
# update the number of seen tokens
if self.layer_idx == 0:
past_key_value.seen_tokens += key_states.shape[-2]
if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
# reuse k, v, self_attention
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
if len(past_key_value.key_cache) <= self.layer_idx:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
else:
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]
new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v
if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
key_states, value_states = append_kv_cache(cache_k, cache_v,
key_states, value_states)
new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v
# update past_key_value
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
key_states, value_states = append_kv_cache(cache_k, cache_v,
key_states, value_states)
# update past_key_value
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
if not self.training and not hidden_states.requires_grad:
fsdp_flag = use_flash_attention(query_states, key_states)
@@ -1035,6 +1054,9 @@ def mistral_attention_forward_4_36_original(
elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
# new fp16 sdp doesn't require repeat_kv
import xe_addons
# [SnapKV] set attention_mask = None
if use_compresskv:
attention_mask = None
attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
attn_output = attn_output.view(query_states.shape)
attn_weights = None
@@ -1119,6 +1141,9 @@ def mistral_attention_forward_4_39_original(
# for flash attention
original_dtype = hidden_states.dtype
# [SnapKV]
use_compresskv = should_use_compresskv(hidden_states)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
decoding_fast_path = use_decoding_fast_path(self.q_proj,
@@ -1149,11 +1174,14 @@ def mistral_attention_forward_4_39_original(
kv_seq_len += 1
# update past_key_value's seen_tokens and kv caches.
if self.layer_idx == 0:
# [SnapKV]
if use_compresskv:
past_key_value.update_seen_tokens(self.layer_idx, q_len)
kv_seq_len = past_key_value.get_seq_length()
elif self.layer_idx == 0:
past_key_value._seen_tokens = kv_seq_len
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
else:
if should_use_xetla_mm_qkv(self, device):
if not hasattr(self, "qkv_proj_qweight"):
@@ -1204,40 +1232,47 @@ def mistral_attention_forward_4_39_original(
cos, sin, position_ids, "mistral")
if past_key_value is not None:
# update the number of seen tokens
if self.layer_idx == 0:
past_key_value._seen_tokens += key_states.shape[-2]
# reuse k, v, self_attention
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
if len(past_key_value.key_cache) <= self.layer_idx:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
if use_compresskv:
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
query_states, attention_mask, self.num_key_value_groups,
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
else:
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]
# update the number of seen tokens
if self.layer_idx == 0:
past_key_value._seen_tokens += key_states.shape[-2]
if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
# reuse k, v, self_attention
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
if len(past_key_value.key_cache) <= self.layer_idx:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
else:
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]
new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v
if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
key_states, value_states = append_kv_cache(cache_k, cache_v,
key_states, value_states)
new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v
# update past_key_value
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
key_states, value_states = append_kv_cache(cache_k, cache_v,
key_states,
value_states)
# update past_key_value
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
if not self.training and not hidden_states.requires_grad:
fsdp_flag = use_flash_attention(query_states, key_states)
@@ -1264,6 +1299,9 @@ def mistral_attention_forward_4_39_original(
elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
# new fp16 sdp doesn't require repeat_kv
import xe_addons
# [SnapKV] set attention_mask = None
if use_compresskv:
attention_mask = None
attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
attn_output = attn_output.view(query_states.shape)
attn_weights = None


@@ -479,3 +479,8 @@ def update_past_key_value(past_key_value, key_states, value_states,
v_cache = new_v_cache
key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states)
return key_states, value_states
def should_use_compresskv(x: torch.Tensor):
use_compress_kv = os.environ.get("IPEX_LLM_COMPRESS_KV_CACHE", None)
return x.device.type == 'xpu' and use_compress_kv == "1"
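
should_use_compresskv gates the whole feature on two conditions: the input tensor lives on an XPU device and IPEX_LLM_COMPRESS_KV_CACHE is set to "1". A small illustration (the xpu line is commented out because it needs Intel GPU support; the import path matches the one used in the model forwards above):

    import os
    import torch
    from ipex_llm.transformers.models.utils import should_use_compresskv

    os.environ["IPEX_LLM_COMPRESS_KV_CACHE"] = "1"
    print(should_use_compresskv(torch.empty(1)))  # False: CPU tensor, the flag alone is not enough
    # On an XPU device the same flag turns the feature on:
    # print(should_use_compresskv(torch.empty(1, device="xpu")))  # True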