Support compress kv (#11642)
* mistral snapkv * update * mtl update * update * update * update * add comments * style fix * fix style * support llama * llama use compress kv * support mistral 4.40 * fix style * support diff transformers versions * move snapkv util to kv * fix style * meet comments & small fix * revert all in one * fix indent --------- Co-authored-by: leonardozcm <leonardo1997zcm@gmail.com>
This commit is contained in:
		
							parent
							
								
									6bcdc6cc8f
								
							
						
					
					
						commit
						fc7f8feb83
					
				
					 5 changed files with 422 additions and 152 deletions
				
			
		| 
						 | 
				
			
			@ -1443,14 +1443,14 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
            if version.parse(trans_version) >= version.parse("4.36.0"):
 | 
			
		||||
                from ipex_llm.transformers.models.mistral import mistral_model_forward_4_36
 | 
			
		||||
                if version.parse(trans_version) >= version.parse("4.39.0"):
 | 
			
		||||
                    from ipex_llm.transformers.models.mistral import mistral_attention_forward_4_39
 | 
			
		||||
                    from ipex_llm.transformers.models.mistral import \
 | 
			
		||||
                        mistral_attention_forward_4_39
 | 
			
		||||
                    convert_forward(model,
 | 
			
		||||
                                    module.MistralAttention,
 | 
			
		||||
                                    mistral_attention_forward_4_39
 | 
			
		||||
                                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    from ipex_llm.transformers.models.mistral import mistral_attention_forward_4_36
 | 
			
		||||
 | 
			
		||||
                    convert_forward(model,
 | 
			
		||||
                                    module.MistralAttention,
 | 
			
		||||
                                    mistral_attention_forward_4_36
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,13 +16,17 @@
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
from .models.utils import (
 | 
			
		||||
    init_fp8_kv_cache, append_fp8_kv_cache,
 | 
			
		||||
    init_kv_cache, append_kv_cache
 | 
			
		||||
    init_kv_cache, append_kv_cache, extend_kv_cache
 | 
			
		||||
)
 | 
			
		||||
from typing import Optional, Dict, Tuple, Any
 | 
			
		||||
from transformers.cache_utils import DynamicCache
 | 
			
		||||
from ipex_llm.utils.common.log4Error import invalidInputError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DynamicFp8Cache(DynamicCache):
 | 
			
		||||
| 
						 | 
				
			
			@ -116,3 +120,178 @@ class DynamicNormalCache(DynamicCache):
 | 
			
		|||
            self.value_cache[layer_idx] = v_cache
 | 
			
		||||
 | 
			
		||||
        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv
 | 
			
		||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
 | 
			
		||||
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
 | 
			
		||||
    to (batch, num_attention_heads, seqlen, head_dim)
 | 
			
		||||
    """
 | 
			
		||||
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
 | 
			
		||||
    if n_rep == 1:
 | 
			
		||||
        return hidden_states
 | 
			
		||||
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
 | 
			
		||||
                                                           n_rep, slen, head_dim)
 | 
			
		||||
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# This function is adapted from
 | 
			
		||||
# https://github.com/FasterDecoding/SnapKV/blob/main/snapkv/monkeypatch/snapkv_utils.py
 | 
			
		||||
def compress_kv(attn_config, key_states, query_states, value_states, attention_mask,
 | 
			
		||||
                num_key_value_groups):
 | 
			
		||||
    # check if prefix phase
 | 
			
		||||
    invalidInputError(key_states.shape[-2] == query_states.shape[-2], "kv shape mismatch.")
 | 
			
		||||
    if not hasattr(attn_config, 'window_size'):
 | 
			
		||||
        attn_config.window_size = 32
 | 
			
		||||
    if not hasattr(attn_config, 'max_capacity_prompt'):
 | 
			
		||||
        attn_config.max_capacity_prompt = 512
 | 
			
		||||
    if not hasattr(attn_config, 'kernel_size'):
 | 
			
		||||
        attn_config.kernel_size = 5
 | 
			
		||||
    if not hasattr(attn_config, 'pooling'):
 | 
			
		||||
        attn_config.pooling = 'avgpool'
 | 
			
		||||
    bsz, num_heads, q_len, head_dim = query_states.shape
 | 
			
		||||
    if q_len < attn_config.max_capacity_prompt:
 | 
			
		||||
        return key_states, value_states
 | 
			
		||||
    else:
 | 
			
		||||
        key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)
 | 
			
		||||
        attn_weights = torch.matmul(query_states[..., -attn_config.window_size:, :],
 | 
			
		||||
                                    key_states_expand.transpose(2, 3)) / math.sqrt(head_dim)
 | 
			
		||||
        mask = torch.full((attn_config.window_size, attn_config.window_size),
 | 
			
		||||
                          torch.finfo(attn_weights.dtype).min,
 | 
			
		||||
                          device=attn_weights.device)
 | 
			
		||||
        mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
 | 
			
		||||
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
 | 
			
		||||
        mask = mask.to(attn_weights.device)
 | 
			
		||||
        attention_mask = mask[None, None, :, :]
 | 
			
		||||
 | 
			
		||||
        attn_weights[:, :, -attn_config.window_size:, -attn_config.window_size:] += attention_mask
 | 
			
		||||
 | 
			
		||||
        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
 | 
			
		||||
                                             dtype=torch.float32).to(query_states.dtype)
 | 
			
		||||
        attn_weights_sum = attn_weights[:, :, -attn_config.window_size:,
 | 
			
		||||
                                        :-attn_config.window_size].sum(dim=-2)
 | 
			
		||||
        if attn_config.pooling == 'avgpool':
 | 
			
		||||
            if num_key_value_groups > 1:
 | 
			
		||||
                attn_cache = F.avg_pool2d(attn_weights_sum, kernel_size=(num_key_value_groups,
 | 
			
		||||
                                                                         attn_config.kernel_size),
 | 
			
		||||
                                          padding=(0, attn_config.kernel_size//2),
 | 
			
		||||
                                          stride=(num_key_value_groups, 1))
 | 
			
		||||
            else:
 | 
			
		||||
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
 | 
			
		||||
                                          padding=attn_config.kernel_size//2, stride=1)
 | 
			
		||||
        elif attn_config.pooling == 'maxpool':
 | 
			
		||||
            if num_key_value_groups > 1:
 | 
			
		||||
                attn_cache = F.max_pool2d(attn_weights_sum,
 | 
			
		||||
                                          kernel_size=(num_key_value_groups,
 | 
			
		||||
                                                       attn_config.kernel_size),
 | 
			
		||||
                                          padding=(0, attn_config.kernel_size//2),
 | 
			
		||||
                                          stride=(num_key_value_groups, 1))
 | 
			
		||||
            else:
 | 
			
		||||
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
 | 
			
		||||
                                          padding=attn_config.kernel_size//2, stride=1)
 | 
			
		||||
        else:
 | 
			
		||||
            invalidInputError(False, 'Pooling method not supported')
 | 
			
		||||
        indices = attn_cache.topk(attn_config.max_capacity_prompt - attn_config.window_size,
 | 
			
		||||
                                  dim=-1).indices
 | 
			
		||||
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
 | 
			
		||||
        k_past_compress = key_states[:, :, :-attn_config.window_size, :].gather(dim=2,
 | 
			
		||||
                                                                                index=indices)
 | 
			
		||||
        v_past_compress = value_states[:, :, :-attn_config.window_size, :].gather(dim=2,
 | 
			
		||||
                                                                                  index=indices)
 | 
			
		||||
        k_cur = key_states[:, :, -attn_config.window_size:, :]
 | 
			
		||||
        v_cur = value_states[:, :, -attn_config.window_size:, :]
 | 
			
		||||
        key_states = torch.cat([k_past_compress, k_cur], dim=2)
 | 
			
		||||
        value_states = torch.cat([v_past_compress, v_cur], dim=2)
 | 
			
		||||
        return key_states, value_states
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DynamicCompressCache(DynamicCache):
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
        self.real_kv_len = 0
 | 
			
		||||
 | 
			
		||||
    def update_seen_tokens(self, layer_idx, q_len):
 | 
			
		||||
        if layer_idx == 0:
 | 
			
		||||
            if hasattr(self, "_seen_tokens"):
 | 
			
		||||
                # 4.39 uses `_seen_tokens`
 | 
			
		||||
                self._seen_tokens += q_len
 | 
			
		||||
            else:
 | 
			
		||||
                # 4.37 uses `seen_tokens`
 | 
			
		||||
                self.seen_tokens += q_len
 | 
			
		||||
            self.real_kv_len += q_len
 | 
			
		||||
 | 
			
		||||
    def update(
 | 
			
		||||
        self,
 | 
			
		||||
        key_states: torch.Tensor,
 | 
			
		||||
        value_states: torch.Tensor,
 | 
			
		||||
        layer_idx: int,
 | 
			
		||||
        query_states: torch.Tensor,
 | 
			
		||||
        attention_mask: torch.Tensor,
 | 
			
		||||
        num_key_value_groups: int,
 | 
			
		||||
        attn_config: Dict[str, Any],
 | 
			
		||||
        enough_kv_room: bool,
 | 
			
		||||
        KV_CACHE_ALLOC_BLOCK_LENGTH: int,
 | 
			
		||||
        cache_kwargs: Optional[Dict[str, Any]]=None,
 | 
			
		||||
    ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
 | 
			
		||||
        bsz, num_heads, seq_len, head_dim = key_states.shape
 | 
			
		||||
 | 
			
		||||
        if layer_idx == 0:
 | 
			
		||||
            if hasattr(self, "_seen_tokens"):
 | 
			
		||||
                # 4.39 uses `_seen_tokens`
 | 
			
		||||
                self._seen_tokens += seq_len
 | 
			
		||||
            else:
 | 
			
		||||
                # 4.37 uses `seen_tokens`
 | 
			
		||||
                self.seen_tokens += seq_len
 | 
			
		||||
            self.real_kv_len += seq_len
 | 
			
		||||
 | 
			
		||||
        # Update the cache
 | 
			
		||||
        if len(self.key_cache) <= layer_idx:
 | 
			
		||||
            # First token, compress kv cache
 | 
			
		||||
            key_states_compress, value_states_compress = compress_kv(
 | 
			
		||||
                attn_config=attn_config,
 | 
			
		||||
                key_states=key_states,
 | 
			
		||||
                query_states=query_states,
 | 
			
		||||
                value_states=value_states,
 | 
			
		||||
                attention_mask=attention_mask,
 | 
			
		||||
                num_key_value_groups=num_key_value_groups)
 | 
			
		||||
            self.key_cache.append(key_states_compress)
 | 
			
		||||
            self.value_cache.append(value_states_compress)
 | 
			
		||||
            return key_states, value_states
 | 
			
		||||
        else:
 | 
			
		||||
            cache_k = self.key_cache[layer_idx]
 | 
			
		||||
            cache_v = self.value_cache[layer_idx]
 | 
			
		||||
            if not enough_kv_room:
 | 
			
		||||
                # allocate new
 | 
			
		||||
                new_c_k, new_c_v = extend_kv_cache(bsz,
 | 
			
		||||
                                                   num_heads,  # Support GQA
 | 
			
		||||
                                                   head_dim,
 | 
			
		||||
                                                   cache_k.size(2),
 | 
			
		||||
                                                   cache_k.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
			
		||||
                                                   dtype=cache_k.dtype,
 | 
			
		||||
                                                   device=query_states.device)
 | 
			
		||||
 | 
			
		||||
                new_c_k[:] = cache_k
 | 
			
		||||
                new_c_v[:] = cache_v
 | 
			
		||||
                cache_k = new_c_k
 | 
			
		||||
                cache_v = new_c_v
 | 
			
		||||
 | 
			
		||||
            key_states, value_states = append_kv_cache(cache_k,
 | 
			
		||||
                                                       cache_v,
 | 
			
		||||
                                                       key_states,
 | 
			
		||||
                                                       value_states)
 | 
			
		||||
 | 
			
		||||
            # update past_key_value
 | 
			
		||||
            self.key_cache[layer_idx] = key_states
 | 
			
		||||
            self.value_cache[layer_idx] = value_states
 | 
			
		||||
 | 
			
		||||
            return self.key_cache[layer_idx], self.value_cache[layer_idx]
 | 
			
		||||
 | 
			
		||||
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
 | 
			
		||||
        """Returns the sequence length of the cached states. A layer
 | 
			
		||||
        index can be optionally passed."""
 | 
			
		||||
        if len(self.key_cache) <= layer_idx:
 | 
			
		||||
            return 0
 | 
			
		||||
        return self.real_kv_len
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -42,7 +42,7 @@ import torch.nn.functional as F
 | 
			
		|||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import SILU
 | 
			
		||||
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
 | 
			
		||||
    restore_fp8_kv_cache, use_quantize_kv_cache
 | 
			
		||||
    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
 | 
			
		||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
 | 
			
		||||
    apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
| 
						 | 
				
			
			@ -113,12 +113,18 @@ def llama_model_forward_4_36(
 | 
			
		|||
    output_hidden_states: Optional[bool] = None,
 | 
			
		||||
    return_dict: Optional[bool] = None,
 | 
			
		||||
) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    input = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
 | 
			
		||||
            if not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
        elif should_use_compresskv(input):
 | 
			
		||||
            # if use quantize kv, compress kv will be ignored now
 | 
			
		||||
            if not isinstance(past_key_values, DynamicCompressCache):
 | 
			
		||||
                past_key_values = DynamicCompressCache.from_legacy_cache(
 | 
			
		||||
                    past_key_values)
 | 
			
		||||
    return llama_model_forward_4_36_internal(
 | 
			
		||||
        self=self,
 | 
			
		||||
        input_ids=input_ids,
 | 
			
		||||
| 
						 | 
				
			
			@ -146,12 +152,18 @@ def llama_model_forward_4_38(
 | 
			
		|||
    return_dict: Optional[bool] = None,
 | 
			
		||||
    cache_position: Optional[torch.LongTensor] = None,
 | 
			
		||||
) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    input = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
 | 
			
		||||
            if not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
        elif should_use_compresskv(input):
 | 
			
		||||
            # if use quantize kv, compress kv will be ignored now
 | 
			
		||||
            if not isinstance(past_key_values, DynamicCompressCache):
 | 
			
		||||
                past_key_values = DynamicCompressCache.from_legacy_cache(
 | 
			
		||||
                    past_key_values)
 | 
			
		||||
    return llama_model_forward_4_38_internal(
 | 
			
		||||
        self=self,
 | 
			
		||||
        input_ids=input_ids,
 | 
			
		||||
| 
						 | 
				
			
			@ -180,12 +192,18 @@ def llama_model_forward_4_41(
 | 
			
		|||
    return_dict: Optional[bool] = None,
 | 
			
		||||
    cache_position: Optional[torch.LongTensor] = None,
 | 
			
		||||
) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    input = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
 | 
			
		||||
            if not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
        elif should_use_compresskv(input):
 | 
			
		||||
            # if use quantize kv, compress kv will be ignored now
 | 
			
		||||
            if not isinstance(past_key_values, DynamicCompressCache):
 | 
			
		||||
                past_key_values = DynamicCompressCache.from_legacy_cache(
 | 
			
		||||
                    past_key_values)
 | 
			
		||||
    return llama_model_forward_4_41_internal(
 | 
			
		||||
        self=self,
 | 
			
		||||
        input_ids=input_ids,
 | 
			
		||||
| 
						 | 
				
			
			@ -1267,6 +1285,9 @@ def llama_attention_forward_4_41_original(
 | 
			
		|||
    # for flash attention
 | 
			
		||||
    original_dtype = hidden_states.dtype
 | 
			
		||||
 | 
			
		||||
    # [SnapKV]
 | 
			
		||||
    use_compresskv = should_use_compresskv(hidden_states)
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
 | 
			
		||||
    no_tp = not self.config.pretraining_tp > 1
 | 
			
		||||
| 
						 | 
				
			
			@ -1299,7 +1320,11 @@ def llama_attention_forward_4_41_original(
 | 
			
		|||
                                                                       self.rotary_emb.base,)
 | 
			
		||||
        kv_seq_len += 1
 | 
			
		||||
        # update past_key_value's seem_tokens and kv caches.
 | 
			
		||||
        if self.layer_idx == 0:
 | 
			
		||||
        # [SnapKV]
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            past_key_value.update_seen_tokens(self.layer_idx, q_len)
 | 
			
		||||
            kv_seq_len = past_key_value.get_seq_length()
 | 
			
		||||
        elif self.layer_idx == 0:
 | 
			
		||||
            past_key_value._seen_tokens = kv_seq_len
 | 
			
		||||
        past_key_value.key_cache[self.layer_idx] = key_states
 | 
			
		||||
        past_key_value.value_cache[self.layer_idx] = value_states
 | 
			
		||||
| 
						 | 
				
			
			@ -1404,6 +1429,12 @@ def llama_attention_forward_4_41_original(
 | 
			
		|||
                                                                cos, sin, position_ids, "llama")
 | 
			
		||||
 | 
			
		||||
        if past_key_value is not None:
 | 
			
		||||
            if use_compresskv:
 | 
			
		||||
                key_states, value_states = past_key_value.update(
 | 
			
		||||
                    key_states, value_states, self.layer_idx,
 | 
			
		||||
                    query_states, attention_mask, self.num_key_value_groups,
 | 
			
		||||
                    self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
 | 
			
		||||
            else:
 | 
			
		||||
                # update the number of seen tokens
 | 
			
		||||
                if self.layer_idx == 0:
 | 
			
		||||
                    past_key_value._seen_tokens += key_states.shape[-2]
 | 
			
		||||
| 
						 | 
				
			
			@ -1443,7 +1474,6 @@ def llama_attention_forward_4_41_original(
 | 
			
		|||
 | 
			
		||||
    if cache_position is not None:
 | 
			
		||||
        new_attention_mask = attention_mask[:, :, :, 0:kv_seq_len]
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
        new_attention_mask = attention_mask
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1461,6 +1491,9 @@ def llama_attention_forward_4_41_original(
 | 
			
		|||
    elif not self.training and not hidden_states.requires_grad and \
 | 
			
		||||
            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            # [SnapKV] set attention_mask = None
 | 
			
		||||
            new_attention_mask = None
 | 
			
		||||
        attn_output = xe_addons.sdp(query_states, key_states, value_states,
 | 
			
		||||
                                    new_attention_mask)
 | 
			
		||||
        attn_output = attn_output.view(query_states.shape)
 | 
			
		||||
| 
						 | 
				
			
			@ -1791,6 +1824,9 @@ def llama_attention_forward_4_38_original(
 | 
			
		|||
    # for flash attention
 | 
			
		||||
    original_dtype = hidden_states.dtype
 | 
			
		||||
 | 
			
		||||
    # [SnapKV]
 | 
			
		||||
    use_compresskv = should_use_compresskv(hidden_states)
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
 | 
			
		||||
    no_tp = not self.config.pretraining_tp > 1
 | 
			
		||||
| 
						 | 
				
			
			@ -1823,11 +1859,14 @@ def llama_attention_forward_4_38_original(
 | 
			
		|||
                                                                       self.rotary_emb.base,)
 | 
			
		||||
        kv_seq_len += 1
 | 
			
		||||
        # update past_key_value's seem_tokens and kv caches.
 | 
			
		||||
        if self.layer_idx == 0:
 | 
			
		||||
        # [SnapKV]
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            past_key_value.update_seen_tokens(self.layer_idx, q_len)
 | 
			
		||||
            kv_seq_len = past_key_value.get_seq_length()
 | 
			
		||||
        elif self.layer_idx == 0:
 | 
			
		||||
            past_key_value.seen_tokens = kv_seq_len
 | 
			
		||||
        past_key_value.key_cache[self.layer_idx] = key_states
 | 
			
		||||
        past_key_value.value_cache[self.layer_idx] = value_states
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
        if self.config.pretraining_tp > 1:
 | 
			
		||||
            key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
 | 
			
		||||
| 
						 | 
				
			
			@ -1928,6 +1967,12 @@ def llama_attention_forward_4_38_original(
 | 
			
		|||
                                                                cos, sin, position_ids, "llama")
 | 
			
		||||
 | 
			
		||||
        if past_key_value is not None:
 | 
			
		||||
            if use_compresskv:
 | 
			
		||||
                key_states, value_states = past_key_value.update(
 | 
			
		||||
                    key_states, value_states, self.layer_idx,
 | 
			
		||||
                    query_states, attention_mask, self.num_key_value_groups,
 | 
			
		||||
                    self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
 | 
			
		||||
            else:
 | 
			
		||||
                # update the number of seen tokens
 | 
			
		||||
                if self.layer_idx == 0:
 | 
			
		||||
                    past_key_value.seen_tokens += key_states.shape[-2]
 | 
			
		||||
| 
						 | 
				
			
			@ -1984,6 +2029,9 @@ def llama_attention_forward_4_38_original(
 | 
			
		|||
    elif not self.training and not hidden_states.requires_grad and \
 | 
			
		||||
            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            # [SnapKV] set attention_mask = None
 | 
			
		||||
            new_attention_mask = None
 | 
			
		||||
        attn_output = xe_addons.sdp(query_states, key_states, value_states,
 | 
			
		||||
                                    new_attention_mask)
 | 
			
		||||
        attn_output = attn_output.view(query_states.shape)
 | 
			
		||||
| 
						 | 
				
			
			@ -2515,11 +2563,11 @@ def llama_model_forward_4_41_internal(
 | 
			
		|||
        all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
    next_cache = None
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        next_cache = (
 | 
			
		||||
            next_decoder_cache.to_legacy_cache()
 | 
			
		||||
            if not isinstance(next_decoder_cache, DynamicFp8Cache)
 | 
			
		||||
            if not isinstance(next_decoder_cache, (DynamicFp8Cache, DynamicCompressCache))
 | 
			
		||||
            else next_decoder_cache
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -2645,11 +2693,11 @@ def llama_model_forward_4_38_internal(
 | 
			
		|||
        all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
    next_cache = None
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        next_cache = (
 | 
			
		||||
            next_decoder_cache.to_legacy_cache()
 | 
			
		||||
            if not isinstance(next_decoder_cache, DynamicFp8Cache)
 | 
			
		||||
            if not isinstance(next_decoder_cache, (DynamicFp8Cache, DynamicCompressCache))
 | 
			
		||||
            else next_decoder_cache
 | 
			
		||||
        )
 | 
			
		||||
    if not return_dict:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -46,7 +46,7 @@ from transformers.models.mistral.modeling_mistral import MistralModel
 | 
			
		|||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
 | 
			
		||||
    restore_fp8_kv_cache, use_quantize_kv_cache
 | 
			
		||||
    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
 | 
			
		||||
    apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
 | 
			
		||||
| 
						 | 
				
			
			@ -202,11 +202,17 @@ def mistral_model_forward_4_36(
 | 
			
		|||
    output_hidden_states: Optional[bool] = None,
 | 
			
		||||
    return_dict: Optional[bool] = None,
 | 
			
		||||
) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
 | 
			
		||||
            if not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
        elif should_use_compresskv(input_ids):
 | 
			
		||||
            # if use quantize kv, compress kv will be ignored now
 | 
			
		||||
            if not isinstance(past_key_values, DynamicCompressCache):
 | 
			
		||||
                past_key_values = DynamicCompressCache.from_legacy_cache(
 | 
			
		||||
                    past_key_values)
 | 
			
		||||
    return MistralModel.forward(
 | 
			
		||||
        self=self,
 | 
			
		||||
        input_ids=input_ids,
 | 
			
		||||
| 
						 | 
				
			
			@ -890,6 +896,9 @@ def mistral_attention_forward_4_36_original(
 | 
			
		|||
    # for flash attention
 | 
			
		||||
    original_dtype = hidden_states.dtype
 | 
			
		||||
 | 
			
		||||
    # [SnapKV]
 | 
			
		||||
    use_compresskv = should_use_compresskv(hidden_states)
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
 | 
			
		||||
    decoding_fast_path = use_decoding_fast_path(self.q_proj,
 | 
			
		||||
| 
						 | 
				
			
			@ -920,7 +929,11 @@ def mistral_attention_forward_4_36_original(
 | 
			
		|||
        kv_seq_len += 1
 | 
			
		||||
 | 
			
		||||
        # update past_key_value's seem_tokens and kv caches.
 | 
			
		||||
        if self.layer_idx == 0:
 | 
			
		||||
        # [SnapKV]
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            past_key_value.update_seen_tokens(self.layer_idx, q_len)
 | 
			
		||||
            kv_seq_len = past_key_value.get_seq_length()
 | 
			
		||||
        elif self.layer_idx == 0:
 | 
			
		||||
            past_key_value.seen_tokens = kv_seq_len
 | 
			
		||||
        past_key_value.key_cache[self.layer_idx] = key_states
 | 
			
		||||
        past_key_value.value_cache[self.layer_idx] = value_states
 | 
			
		||||
| 
						 | 
				
			
			@ -975,6 +988,12 @@ def mistral_attention_forward_4_36_original(
 | 
			
		|||
                                                            cos, sin, position_ids, "mistral")
 | 
			
		||||
 | 
			
		||||
        if past_key_value is not None:
 | 
			
		||||
            if use_compresskv:
 | 
			
		||||
                key_states, value_states = past_key_value.update(
 | 
			
		||||
                    key_states, value_states, self.layer_idx,
 | 
			
		||||
                    query_states, attention_mask, self.num_key_value_groups,
 | 
			
		||||
                    self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
 | 
			
		||||
            else:
 | 
			
		||||
                # update the number of seen tokens
 | 
			
		||||
                if self.layer_idx == 0:
 | 
			
		||||
                    past_key_value.seen_tokens += key_states.shape[-2]
 | 
			
		||||
| 
						 | 
				
			
			@ -1035,6 +1054,9 @@ def mistral_attention_forward_4_36_original(
 | 
			
		|||
    elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
 | 
			
		||||
        # new fp16 sdp doesn't require repeat_kv
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        # [SnapKV] set attention_mask = None
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            attention_mask = None
 | 
			
		||||
        attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
 | 
			
		||||
        attn_output = attn_output.view(query_states.shape)
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
| 
						 | 
				
			
			@ -1119,6 +1141,9 @@ def mistral_attention_forward_4_39_original(
 | 
			
		|||
    # for flash attention
 | 
			
		||||
    original_dtype = hidden_states.dtype
 | 
			
		||||
 | 
			
		||||
    # [SnapKV]
 | 
			
		||||
    use_compresskv = should_use_compresskv(hidden_states)
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
 | 
			
		||||
    decoding_fast_path = use_decoding_fast_path(self.q_proj,
 | 
			
		||||
| 
						 | 
				
			
			@ -1149,11 +1174,14 @@ def mistral_attention_forward_4_39_original(
 | 
			
		|||
        kv_seq_len += 1
 | 
			
		||||
 | 
			
		||||
        # update past_key_value's seem_tokens and kv caches.
 | 
			
		||||
        if self.layer_idx == 0:
 | 
			
		||||
        # [SnapKV]
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            past_key_value.update_seen_tokens(self.layer_idx, q_len)
 | 
			
		||||
            kv_seq_len = past_key_value.get_seq_length()
 | 
			
		||||
        elif self.layer_idx == 0:
 | 
			
		||||
            past_key_value._seen_tokens = kv_seq_len
 | 
			
		||||
        past_key_value.key_cache[self.layer_idx] = key_states
 | 
			
		||||
        past_key_value.value_cache[self.layer_idx] = value_states
 | 
			
		||||
 | 
			
		||||
    else:
 | 
			
		||||
        if should_use_xetla_mm_qkv(self, device):
 | 
			
		||||
            if not hasattr(self, "qkv_proj_qweight"):
 | 
			
		||||
| 
						 | 
				
			
			@ -1204,6 +1232,12 @@ def mistral_attention_forward_4_39_original(
 | 
			
		|||
                                                            cos, sin, position_ids, "mistral")
 | 
			
		||||
 | 
			
		||||
        if past_key_value is not None:
 | 
			
		||||
            if use_compresskv:
 | 
			
		||||
                key_states, value_states = past_key_value.update(
 | 
			
		||||
                    key_states, value_states, self.layer_idx,
 | 
			
		||||
                    query_states, attention_mask, self.num_key_value_groups,
 | 
			
		||||
                    self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
 | 
			
		||||
            else:
 | 
			
		||||
                # update the number of seen tokens
 | 
			
		||||
                if self.layer_idx == 0:
 | 
			
		||||
                    past_key_value._seen_tokens += key_states.shape[-2]
 | 
			
		||||
| 
						 | 
				
			
			@ -1233,7 +1267,8 @@ def mistral_attention_forward_4_39_original(
 | 
			
		|||
                        cache_v = new_c_v
 | 
			
		||||
 | 
			
		||||
                    key_states, value_states = append_kv_cache(cache_k, cache_v,
 | 
			
		||||
                                                           key_states, value_states)
 | 
			
		||||
                                                               key_states,
 | 
			
		||||
                                                               value_states)
 | 
			
		||||
 | 
			
		||||
                    # update past_key_value
 | 
			
		||||
                    past_key_value.key_cache[self.layer_idx] = key_states
 | 
			
		||||
| 
						 | 
				
			
			@ -1264,6 +1299,9 @@ def mistral_attention_forward_4_39_original(
 | 
			
		|||
    elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
 | 
			
		||||
        # new fp16 sdp doesn't require repeat_kv
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        # [SnapKV] set attention_mask = None
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            attention_mask = None
 | 
			
		||||
        attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
 | 
			
		||||
        attn_output = attn_output.view(query_states.shape)
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -479,3 +479,8 @@ def update_past_key_value(past_key_value, key_states, value_states,
 | 
			
		|||
                v_cache = new_v_cache
 | 
			
		||||
            key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states)
 | 
			
		||||
    return key_states, value_states
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def should_use_compresskv(x: torch.Tensor):
 | 
			
		||||
    use_compress_kv = os.environ.get("IPEX_LLM_COMPRESS_KV_CACHE", None)
 | 
			
		||||
    return x.device.type == 'xpu' and use_compress_kv == "1"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue