diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 6791e138..936fab56 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1443,14 +1443,14 @@ def _optimize_post(model, lightweight_bmm=False):
         if version.parse(trans_version) >= version.parse("4.36.0"):
             from ipex_llm.transformers.models.mistral import mistral_model_forward_4_36
             if version.parse(trans_version) >= version.parse("4.39.0"):
-                from ipex_llm.transformers.models.mistral import mistral_attention_forward_4_39
+                from ipex_llm.transformers.models.mistral import \
+                    mistral_attention_forward_4_39
                 convert_forward(model,
                                 module.MistralAttention,
                                 mistral_attention_forward_4_39
                                 )
             else:
                 from ipex_llm.transformers.models.mistral import mistral_attention_forward_4_36
-                convert_forward(model, module.MistralAttention, mistral_attention_forward_4_36
diff --git a/python/llm/src/ipex_llm/transformers/kv.py b/python/llm/src/ipex_llm/transformers/kv.py
index d4e9d9d4..c6e79dc6 100644
--- a/python/llm/src/ipex_llm/transformers/kv.py
+++ b/python/llm/src/ipex_llm/transformers/kv.py
@@ -16,13 +16,17 @@
 import torch
+import torch.nn.functional as F
+import torch.nn as nn
+import math
 from .models.utils import (
     init_fp8_kv_cache, append_fp8_kv_cache,
-    init_kv_cache, append_kv_cache
+    init_kv_cache, append_kv_cache, extend_kv_cache
 )
 from typing import Optional, Dict, Tuple, Any
 from transformers.cache_utils import DynamicCache
+from ipex_llm.utils.common.log4Error import invalidInputError
 
 
 class DynamicFp8Cache(DynamicCache):
@@ -116,3 +120,178 @@ class DynamicNormalCache(DynamicCache):
             self.value_cache[layer_idx] = v_cache
 
         return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
+    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
+    to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
+                                                           n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+# This function is adapted from
+# https://github.com/FasterDecoding/SnapKV/blob/main/snapkv/monkeypatch/snapkv_utils.py
+def compress_kv(attn_config, key_states, query_states, value_states, attention_mask,
+                num_key_value_groups):
+    # check if prefix phase
+    invalidInputError(key_states.shape[-2] == query_states.shape[-2], "kv shape mismatch.")
+    if not hasattr(attn_config, 'window_size'):
+        attn_config.window_size = 32
+    if not hasattr(attn_config, 'max_capacity_prompt'):
+        attn_config.max_capacity_prompt = 512
+    if not hasattr(attn_config, 'kernel_size'):
+        attn_config.kernel_size = 5
+    if not hasattr(attn_config, 'pooling'):
+        attn_config.pooling = 'avgpool'
+    bsz, num_heads, q_len, head_dim = query_states.shape
+    if q_len < attn_config.max_capacity_prompt:
+        return key_states, value_states
+    else:
+        key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)
+        attn_weights = torch.matmul(query_states[..., -attn_config.window_size:, :],
+                                    key_states_expand.transpose(2, 3)) / math.sqrt(head_dim)
+        mask = torch.full((attn_config.window_size, attn_config.window_size),
+                          torch.finfo(attn_weights.dtype).min,
+                          device=attn_weights.device)
+        mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
+        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+        mask = mask.to(attn_weights.device)
+        attention_mask = mask[None, None, :, :]
+
+        attn_weights[:, :, -attn_config.window_size:, -attn_config.window_size:] += attention_mask
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                             dtype=torch.float32).to(query_states.dtype)
+        attn_weights_sum = attn_weights[:, :, -attn_config.window_size:,
+                                        :-attn_config.window_size].sum(dim=-2)
+        if attn_config.pooling == 'avgpool':
+            if num_key_value_groups > 1:
+                attn_cache = F.avg_pool2d(attn_weights_sum, kernel_size=(num_key_value_groups,
+                                                                         attn_config.kernel_size),
+                                          padding=(0, attn_config.kernel_size//2),
+                                          stride=(num_key_value_groups, 1))
+            else:
+                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
+                                          padding=attn_config.kernel_size//2, stride=1)
+        elif attn_config.pooling == 'maxpool':
+            if num_key_value_groups > 1:
+                attn_cache = F.max_pool2d(attn_weights_sum,
+                                          kernel_size=(num_key_value_groups,
+                                                       attn_config.kernel_size),
+                                          padding=(0, attn_config.kernel_size//2),
+                                          stride=(num_key_value_groups, 1))
+            else:
+                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
+                                          padding=attn_config.kernel_size//2, stride=1)
+        else:
+            invalidInputError(False, 'Pooling method not supported')
+        indices = attn_cache.topk(attn_config.max_capacity_prompt - attn_config.window_size,
+                                  dim=-1).indices
+        indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
+        k_past_compress = key_states[:, :, :-attn_config.window_size, :].gather(dim=2,
+                                                                                index=indices)
+        v_past_compress = value_states[:, :, :-attn_config.window_size, :].gather(dim=2,
+                                                                                  index=indices)
+        k_cur = key_states[:, :, -attn_config.window_size:, :]
+        v_cur = value_states[:, :, -attn_config.window_size:, :]
+        key_states = torch.cat([k_past_compress, k_cur], dim=2)
+        value_states = torch.cat([v_past_compress, v_cur], dim=2)
+        return key_states, value_states
+
+
+class DynamicCompressCache(DynamicCache):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.real_kv_len = 0
+
+    def update_seen_tokens(self, layer_idx, q_len):
+        if layer_idx == 0:
+            if hasattr(self, "_seen_tokens"):
+                # 4.39 uses `_seen_tokens`
+                self._seen_tokens += q_len
+            else:
+                # 4.37 uses `seen_tokens`
+                self.seen_tokens += q_len
+            self.real_kv_len += q_len
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        query_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        num_key_value_groups: int,
+        attn_config: Dict[str, Any],
+        enough_kv_room: bool,
+        KV_CACHE_ALLOC_BLOCK_LENGTH: int,
+        cache_kwargs: Optional[Dict[str, Any]]=None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        bsz, num_heads, seq_len, head_dim = key_states.shape
+
+        if layer_idx == 0:
+            if hasattr(self, "_seen_tokens"):
+                # 4.39 uses `_seen_tokens`
+                self._seen_tokens += seq_len
+            else:
+                # 4.37 uses `seen_tokens`
+                self.seen_tokens += seq_len
+            self.real_kv_len += seq_len
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            # First token, compress kv cache
+            key_states_compress, value_states_compress = compress_kv(
+                attn_config=attn_config,
+                key_states=key_states,
+                query_states=query_states,
+                value_states=value_states,
+                attention_mask=attention_mask,
+                num_key_value_groups=num_key_value_groups)
+            self.key_cache.append(key_states_compress)
+            self.value_cache.append(value_states_compress)
+            return key_states, value_states
+        else:
+            cache_k = self.key_cache[layer_idx]
+            cache_v = self.value_cache[layer_idx]
+            if not enough_kv_room:
+                # allocate new
+                new_c_k, new_c_v = extend_kv_cache(bsz,
+                                                   num_heads,  # Support GQA
+                                                   head_dim,
+                                                   cache_k.size(2),
+                                                   cache_k.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                   dtype=cache_k.dtype,
+                                                   device=query_states.device)
+
+                new_c_k[:] = cache_k
+                new_c_v[:] = cache_v
+                cache_k = new_c_k
+                cache_v = new_c_v
+
+            key_states, value_states = append_kv_cache(cache_k,
+                                                       cache_v,
+                                                       key_states,
+                                                       value_states)
+
+            # update past_key_value
+            self.key_cache[layer_idx] = key_states
+            self.value_cache[layer_idx] = value_states
+
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer
+        index can be optionally passed."""
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.real_kv_len
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 11425b39..ae9b7812 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -42,7 +42,7 @@ import torch.nn.functional as F
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import SILU
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
-    restore_fp8_kv_cache, use_quantize_kv_cache
+    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -113,12 +113,18 @@ def llama_model_forward_4_36(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
-        if not isinstance(past_key_values, DynamicFp8Cache):
-            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+    if use_cache:
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+            if not isinstance(past_key_values, DynamicFp8Cache):
+                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+        elif should_use_compresskv(input):
+            # if use quantize kv, compress kv will be ignored now
+            if not isinstance(past_key_values, DynamicCompressCache):
+                past_key_values = DynamicCompressCache.from_legacy_cache(
+                    past_key_values)
     return llama_model_forward_4_36_internal(
         self=self,
         input_ids=input_ids,
@@ -146,12 +152,18 @@ def llama_model_forward_4_38(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
-        if not isinstance(past_key_values, DynamicFp8Cache):
-            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+    if use_cache:
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+            if not isinstance(past_key_values, DynamicFp8Cache):
+                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+        elif should_use_compresskv(input):
+            # if use quantize kv, compress kv will be ignored now
+            if not isinstance(past_key_values, DynamicCompressCache):
+                past_key_values = DynamicCompressCache.from_legacy_cache(
+                    past_key_values)
     return llama_model_forward_4_38_internal(
         self=self,
         input_ids=input_ids,
@@ -180,12 +192,18 @@ def llama_model_forward_4_41(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
-        if not isinstance(past_key_values, DynamicFp8Cache):
-            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+    if use_cache:
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+            if not isinstance(past_key_values, DynamicFp8Cache):
+                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+        elif should_use_compresskv(input):
+            # if use quantize kv, compress kv will be ignored now
+            if not isinstance(past_key_values, DynamicCompressCache):
+                past_key_values = DynamicCompressCache.from_legacy_cache(
+                    past_key_values)
     return llama_model_forward_4_41_internal(
         self=self,
         input_ids=input_ids,
@@ -1267,6 +1285,9 @@ def llama_attention_forward_4_41_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
+    # [SnapKV]
+    use_compresskv = should_use_compresskv(hidden_states)
+
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
@@ -1299,7 +1320,11 @@ def llama_attention_forward_4_41_original(
                                                          self.rotary_emb.base,)
         kv_seq_len += 1
         # update past_key_value's seem_tokens and kv caches.
-        if self.layer_idx == 0:
+        # [SnapKV]
+        if use_compresskv:
+            past_key_value.update_seen_tokens(self.layer_idx, q_len)
+            kv_seq_len = past_key_value.get_seq_length()
+        elif self.layer_idx == 0:
             past_key_value._seen_tokens = kv_seq_len
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
@@ -1404,46 +1429,51 @@ def llama_attention_forward_4_41_original(
                                                         cos, sin, position_ids, "llama")
 
     if past_key_value is not None:
-        # update the number of seen tokens
-        if self.layer_idx == 0:
-            past_key_value._seen_tokens += key_states.shape[-2]
-
-        # reuse k, v, self_attention
-        # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
-        if len(past_key_value.key_cache) <= self.layer_idx:
-            past_key_value.key_cache.append(key_states)
-            past_key_value.value_cache.append(value_states)
+        if use_compresskv:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
         else:
-            cache_k = past_key_value.key_cache[self.layer_idx]
-            cache_v = past_key_value.value_cache[self.layer_idx]
+            # update the number of seen tokens
+            if self.layer_idx == 0:
+                past_key_value._seen_tokens += key_states.shape[-2]
 
-            if not enough_kv_room:
-                # allocate new
-                new_c_k, new_c_v = extend_kv_cache(bsz,
-                                                   self.num_key_value_heads,  # Support GQA
-                                                   self.head_dim,
-                                                   cache_k.size(2),
-                                                   kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                   dtype=cache_k.dtype,
-                                                   device=device)
+            # reuse k, v, self_attention
+            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
+            if len(past_key_value.key_cache) <= self.layer_idx:
+                past_key_value.key_cache.append(key_states)
+                past_key_value.value_cache.append(value_states)
+            else:
+                cache_k = past_key_value.key_cache[self.layer_idx]
+                cache_v = past_key_value.value_cache[self.layer_idx]
 
-            new_c_k[:] = cache_k
-            new_c_v[:] = cache_v
-            cache_k = new_c_k
-            cache_v = new_c_v
+                if not enough_kv_room:
+                    # allocate new
+                    new_c_k, new_c_v = extend_kv_cache(bsz,
+                                                       self.num_key_value_heads,  # Support GQA
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
 
-            key_states, value_states = append_kv_cache(cache_k,
-                                                       cache_v,
-                                                       key_states,
-                                                       value_states)
+                    new_c_k[:] = cache_k
+                    new_c_v[:] = cache_v
+                    cache_k = new_c_k
+                    cache_v = new_c_v
 
-            # update past_key_value
-            past_key_value.key_cache[self.layer_idx] = key_states
-            past_key_value.value_cache[self.layer_idx] = value_states
+                key_states, value_states = append_kv_cache(cache_k,
+                                                           cache_v,
+                                                           key_states,
+                                                           value_states)
+
+                # update past_key_value
+                past_key_value.key_cache[self.layer_idx] = key_states
+                past_key_value.value_cache[self.layer_idx] = value_states
 
     if cache_position is not None:
         new_attention_mask = attention_mask[:, :, :, 0:kv_seq_len]
-
     else:
         new_attention_mask = attention_mask
@@ -1461,6 +1491,9 @@ def llama_attention_forward_4_41_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
+        if use_compresskv:
+            # [SnapKV] set attention_mask = None
+            new_attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
@@ -1791,6 +1824,9 @@ def llama_attention_forward_4_38_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
+    # [SnapKV]
+    use_compresskv = should_use_compresskv(hidden_states)
+
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
@@ -1823,11 +1859,14 @@ def llama_attention_forward_4_38_original(
                                                          self.rotary_emb.base,)
         kv_seq_len += 1
         # update past_key_value's seem_tokens and kv caches.
-        if self.layer_idx == 0:
+        # [SnapKV]
+        if use_compresskv:
+            past_key_value.update_seen_tokens(self.layer_idx, q_len)
+            kv_seq_len = past_key_value.get_seq_length()
+        elif self.layer_idx == 0:
             past_key_value.seen_tokens = kv_seq_len
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
-
     else:
         if self.config.pretraining_tp > 1:
             key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
@@ -1928,42 +1967,48 @@ def llama_attention_forward_4_38_original(
                                                         cos, sin, position_ids, "llama")
 
     if past_key_value is not None:
-        # update the number of seen tokens
-        if self.layer_idx == 0:
-            past_key_value.seen_tokens += key_states.shape[-2]
-
-        # reuse k, v, self_attention
-        # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
-        if len(past_key_value.key_cache) <= self.layer_idx:
-            past_key_value.key_cache.append(key_states)
-            past_key_value.value_cache.append(value_states)
+        if use_compresskv:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
         else:
-            cache_k = past_key_value.key_cache[self.layer_idx]
-            cache_v = past_key_value.value_cache[self.layer_idx]
+            # update the number of seen tokens
+            if self.layer_idx == 0:
+                past_key_value.seen_tokens += key_states.shape[-2]
 
-            if not enough_kv_room:
-                # allocate new
-                new_c_k, new_c_v = extend_kv_cache(bsz,
-                                                   self.num_key_value_heads,  # Support GQA
-                                                   self.head_dim,
-                                                   cache_k.size(2),
-                                                   kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                   dtype=cache_k.dtype,
-                                                   device=device)
+            # reuse k, v, self_attention
+            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
+            if len(past_key_value.key_cache) <= self.layer_idx:
+                past_key_value.key_cache.append(key_states)
+                past_key_value.value_cache.append(value_states)
+            else:
+                cache_k = past_key_value.key_cache[self.layer_idx]
+                cache_v = past_key_value.value_cache[self.layer_idx]
 
-            new_c_k[:] = cache_k
-            new_c_v[:] = cache_v
-            cache_k = new_c_k
-            cache_v = new_c_v
+                if not enough_kv_room:
+                    # allocate new
+                    new_c_k, new_c_v = extend_kv_cache(bsz,
+                                                       self.num_key_value_heads,  # Support GQA
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
 
-            key_states, value_states = append_kv_cache(cache_k,
-                                                       cache_v,
-                                                       key_states,
-                                                       value_states)
+                    new_c_k[:] = cache_k
+                    new_c_v[:] = cache_v
+                    cache_k = new_c_k
+                    cache_v = new_c_v
 
-            # update past_key_value
-            past_key_value.key_cache[self.layer_idx] = key_states
-            past_key_value.value_cache[self.layer_idx] = value_states
+                key_states, value_states = append_kv_cache(cache_k,
+                                                           cache_v,
+                                                           key_states,
+                                                           value_states)
+
+                # update past_key_value
+                past_key_value.key_cache[self.layer_idx] = key_states
+                past_key_value.value_cache[self.layer_idx] = value_states
 
     if cache_position is not None:
         new_attention_mask = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
@@ -1984,6 +2029,9 @@ def llama_attention_forward_4_38_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
+        if use_compresskv:
+            # [SnapKV] set attention_mask = None
+            new_attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
@@ -2515,11 +2563,11 @@ def llama_model_forward_4_41_internal(
         all_hidden_states += (hidden_states,)
 
     next_cache = None
-    from ipex_llm.transformers.kv import DynamicFp8Cache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
     if use_cache:
         next_cache = (
             next_decoder_cache.to_legacy_cache()
-            if not isinstance(next_decoder_cache, DynamicFp8Cache)
+            if not isinstance(next_decoder_cache, (DynamicFp8Cache, DynamicCompressCache))
             else next_decoder_cache
         )
@@ -2645,11 +2693,11 @@ def llama_model_forward_4_38_internal(
         all_hidden_states += (hidden_states,)
 
     next_cache = None
-    from ipex_llm.transformers.kv import DynamicFp8Cache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
    if use_cache:
         next_cache = (
             next_decoder_cache.to_legacy_cache()
-            if not isinstance(next_decoder_cache, DynamicFp8Cache)
+            if not isinstance(next_decoder_cache, (DynamicFp8Cache, DynamicCompressCache))
             else next_decoder_cache
         )
     if not return_dict:
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index 1891f982..3f3aa174 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -46,7 +46,7 @@ from transformers.models.mistral.modeling_mistral import MistralModel
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
-    restore_fp8_kv_cache, use_quantize_kv_cache
+    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
     apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
@@ -202,11 +202,17 @@ def mistral_model_forward_4_36(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
-        if not isinstance(past_key_values, DynamicFp8Cache):
-            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+    if use_cache:
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
+            if not isinstance(past_key_values, DynamicFp8Cache):
+                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+        elif should_use_compresskv(input_ids):
+            # if use quantize kv, compress kv will be ignored now
+            if not isinstance(past_key_values, DynamicCompressCache):
+                past_key_values = DynamicCompressCache.from_legacy_cache(
+                    past_key_values)
     return MistralModel.forward(
         self=self,
         input_ids=input_ids,
@@ -890,6 +896,9 @@ def mistral_attention_forward_4_36_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
+    # [SnapKV]
+    use_compresskv = should_use_compresskv(hidden_states)
+
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
@@ -920,7 +929,11 @@ def mistral_attention_forward_4_36_original(
         kv_seq_len += 1
         # update past_key_value's seem_tokens and kv caches.
-        if self.layer_idx == 0:
+        # [SnapKV]
+        if use_compresskv:
+            past_key_value.update_seen_tokens(self.layer_idx, q_len)
+            kv_seq_len = past_key_value.get_seq_length()
+        elif self.layer_idx == 0:
             past_key_value.seen_tokens = kv_seq_len
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
@@ -975,40 +988,46 @@ def mistral_attention_forward_4_36_original(
                                                     cos, sin, position_ids, "mistral")
 
     if past_key_value is not None:
-        # update the number of seen tokens
-        if self.layer_idx == 0:
-            past_key_value.seen_tokens += key_states.shape[-2]
-
-        # reuse k, v, self_attention
-        # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
-        if len(past_key_value.key_cache) <= self.layer_idx:
-            past_key_value.key_cache.append(key_states)
-            past_key_value.value_cache.append(value_states)
+        if use_compresskv:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
         else:
-            cache_k = past_key_value.key_cache[self.layer_idx]
-            cache_v = past_key_value.value_cache[self.layer_idx]
+            # update the number of seen tokens
+            if self.layer_idx == 0:
+                past_key_value.seen_tokens += key_states.shape[-2]
 
-            if not enough_kv_room:
-                # allocate new
-                new_c_k, new_c_v = extend_kv_cache(bsz,
-                                                   self.num_key_value_heads,  # Support GQA
-                                                   self.head_dim,
-                                                   cache_k.size(2),
-                                                   kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                   dtype=cache_k.dtype,
-                                                   device=device)
+            # reuse k, v, self_attention
+            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
+            if len(past_key_value.key_cache) <= self.layer_idx:
+                past_key_value.key_cache.append(key_states)
+                past_key_value.value_cache.append(value_states)
+            else:
+                cache_k = past_key_value.key_cache[self.layer_idx]
+                cache_v = past_key_value.value_cache[self.layer_idx]
 
-            new_c_k[:] = cache_k
-            new_c_v[:] = cache_v
-            cache_k = new_c_k
-            cache_v = new_c_v
+                if not enough_kv_room:
+                    # allocate new
+                    new_c_k, new_c_v = extend_kv_cache(bsz,
+                                                       self.num_key_value_heads,  # Support GQA
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
 
-            key_states, value_states = append_kv_cache(cache_k, cache_v,
-                                                       key_states, value_states)
+                    new_c_k[:] = cache_k
+                    new_c_v[:] = cache_v
+                    cache_k = new_c_k
+                    cache_v = new_c_v
 
-            # update past_key_value
-            past_key_value.key_cache[self.layer_idx] = key_states
-            past_key_value.value_cache[self.layer_idx] = value_states
+                key_states, value_states = append_kv_cache(cache_k, cache_v,
+                                                           key_states, value_states)
+
+                # update past_key_value
+                past_key_value.key_cache[self.layer_idx] = key_states
+                past_key_value.value_cache[self.layer_idx] = value_states
 
     if not self.training and not hidden_states.requires_grad:
         fsdp_flag = use_flash_attention(query_states, key_states)
@@ -1035,6 +1054,9 @@ def mistral_attention_forward_4_36_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
+        # [SnapKV] set attention_mask = None
+        if use_compresskv:
+            attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
@@ -1119,6 +1141,9 @@ def mistral_attention_forward_4_39_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
+    # [SnapKV]
+    use_compresskv = should_use_compresskv(hidden_states)
+
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
@@ -1149,11 +1174,14 @@
         kv_seq_len += 1
         # update past_key_value's seem_tokens and kv caches.
-        if self.layer_idx == 0:
+        # [SnapKV]
+        if use_compresskv:
+            past_key_value.update_seen_tokens(self.layer_idx, q_len)
+            kv_seq_len = past_key_value.get_seq_length()
+        elif self.layer_idx == 0:
             past_key_value._seen_tokens = kv_seq_len
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
-
     else:
         if should_use_xetla_mm_qkv(self, device):
             if not hasattr(self, "qkv_proj_qweight"):
@@ -1204,40 +1232,47 @@
                                                     cos, sin, position_ids, "mistral")
 
     if past_key_value is not None:
-        # update the number of seen tokens
-        if self.layer_idx == 0:
-            past_key_value._seen_tokens += key_states.shape[-2]
-
-        # reuse k, v, self_attention
-        # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
-        if len(past_key_value.key_cache) <= self.layer_idx:
-            past_key_value.key_cache.append(key_states)
-            past_key_value.value_cache.append(value_states)
+        if use_compresskv:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
         else:
-            cache_k = past_key_value.key_cache[self.layer_idx]
-            cache_v = past_key_value.value_cache[self.layer_idx]
+            # update the number of seen tokens
+            if self.layer_idx == 0:
+                past_key_value._seen_tokens += key_states.shape[-2]
 
-            if not enough_kv_room:
-                # allocate new
-                new_c_k, new_c_v = extend_kv_cache(bsz,
-                                                   self.num_key_value_heads,  # Support GQA
-                                                   self.head_dim,
-                                                   cache_k.size(2),
-                                                   kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                   dtype=cache_k.dtype,
-                                                   device=device)
+            # reuse k, v, self_attention
+            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
+            if len(past_key_value.key_cache) <= self.layer_idx:
+                past_key_value.key_cache.append(key_states)
+                past_key_value.value_cache.append(value_states)
+            else:
+                cache_k = past_key_value.key_cache[self.layer_idx]
+                cache_v = past_key_value.value_cache[self.layer_idx]
 
-            new_c_k[:] = cache_k
-            new_c_v[:] = cache_v
-            cache_k = new_c_k
-            cache_v = new_c_v
+                if not enough_kv_room:
+                    # allocate new
+                    new_c_k, new_c_v = extend_kv_cache(bsz,
+                                                       self.num_key_value_heads,  # Support GQA
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
 
-            key_states, value_states = append_kv_cache(cache_k, cache_v,
-                                                       key_states, value_states)
+                    new_c_k[:] = cache_k
+                    new_c_v[:] = cache_v
+                    cache_k = new_c_k
+                    cache_v = new_c_v
 
-            # update past_key_value
-            past_key_value.key_cache[self.layer_idx] = key_states
-            past_key_value.value_cache[self.layer_idx] = value_states
+                key_states, value_states = append_kv_cache(cache_k, cache_v,
+                                                           key_states,
+                                                           value_states)
+
+                # update past_key_value
+                past_key_value.key_cache[self.layer_idx] = key_states
+                past_key_value.value_cache[self.layer_idx] = value_states
 
     if not self.training and not hidden_states.requires_grad:
         fsdp_flag = use_flash_attention(query_states, key_states)
@@ -1264,6 +1299,9 @@ def mistral_attention_forward_4_39_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
+        # [SnapKV] set attention_mask = None
+        if use_compresskv:
+            attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 9d4b44cc..63c71f50 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -479,3 +479,8 @@ def update_past_key_value(past_key_value, key_states, value_states,
         v_cache = new_v_cache
     key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states)
     return key_states, value_states
+
+
+def should_use_compresskv(x: torch.Tensor):
+    use_compress_kv = os.environ.get("IPEX_LLM_COMPRESS_KV_CACHE", None)
+    return x.device.type == 'xpu' and use_compress_kv == "1"
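Usage note (not part of the patch): the compressed KV cache path is opt-in. should_use_compresskv() only returns True when the input tensor lives on an XPU device and the IPEX_LLM_COMPRESS_KV_CACHE environment variable is set to "1", and the model forwards above skip it whenever the quantized (FP8) KV cache is selected. The SnapKV-style compression itself only triggers for prompts longer than max_capacity_prompt (512 tokens by default), keeping the last window_size tokens (32 by default) plus the earlier tokens that receive the most attention from that window. A minimal sketch of how the path might be exercised is below; the model id, prompt, and generation arguments are illustrative only and not taken from the patch.

import os
os.environ["IPEX_LLM_COMPRESS_KV_CACHE"] = "1"   # opt in before the model runs

from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"     # illustrative; a llama or mistral model covered above
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             use_cache=True)
model = model.half().to("xpu")                   # compress kv is only used on xpu devices
tokenizer = AutoTokenizer.from_pretrained(model_path)

# compression only kicks in once the prompt exceeds max_capacity_prompt (512 tokens by default)
prompt = "..."                                   # a long prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("xpu")
output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))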