#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Dict, Tuple, Any
from transformers.cache_utils import DynamicCache

from .models.utils import (
    init_fp8_kv_cache, append_fp8_kv_cache,
    init_kv_cache, append_kv_cache, extend_kv_cache
)
from ipex_llm.utils.common.log4Error import invalidInputError


class DynamicFp8Cache(DynamicCache):
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]]=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_heads, seq_len, head_dim = key_states.shape

        if layer_idx == 0:
            if hasattr(self, "_seen_tokens"):
                # 4.39 uses `_seen_tokens`
                self._seen_tokens += seq_len
            else:
                # 4.37 uses `seen_tokens`
                self.seen_tokens += seq_len

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            k_cache, v_cache = init_fp8_kv_cache(
                batch_size, num_heads, seq_len, head_dim,
                device=key_states.device,
            )
            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)

            self.key_cache.append(k_cache)
            self.value_cache.append(v_cache)
        else:
            k_cache = self.key_cache[layer_idx]
            v_cache = self.value_cache[layer_idx]
            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
            self.key_cache[layer_idx] = k_cache
            self.value_cache[layer_idx] = v_cache

        return self.key_cache[layer_idx], self.value_cache[layer_idx]


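# Illustrative use of DynamicFp8Cache (a sketch, not part of this module's
# public API; `self.layer_idx` refers to a hypothetical attention module's
# layer index, and `past_key_value` is assumed to have been created as
# `DynamicFp8Cache()` before prefill):
#
#     key_states, value_states = past_key_value.update(
#         key_states, value_states, self.layer_idx
#     )

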
class DynamicNormalCache(DynamicCache):
    KV_ALLOC_BLOCK_LENGTH = 256

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]]=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_heads, seq_len, head_dim = key_states.shape

        if layer_idx == 0:
            if hasattr(self, "_seen_tokens"):
                # 4.39 uses `_seen_tokens`
                self._seen_tokens += seq_len
            else:
                # 4.37 uses `seen_tokens`
                self.seen_tokens += seq_len

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            k_cache, v_cache = init_kv_cache(
                batch_size, num_heads, head_dim,
                0, key_states.size(2) + self.KV_ALLOC_BLOCK_LENGTH,
                key_states.dtype, key_states.device
            )
            k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)

            self.key_cache.append(k_cache)
            self.value_cache.append(v_cache)
        else:
            k_cache = self.key_cache[layer_idx]
            v_cache = self.value_cache[layer_idx]

            kv_seq_len = k_cache.size(2) + key_states.size(2)
            if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
                # the pre-allocated buffer can no longer hold the new tokens,
                # so allocate a larger one and copy the old cache over
                new_k_cache, new_v_cache = init_kv_cache(
                    batch_size, num_heads, head_dim,
                    k_cache.size(2), kv_seq_len + self.KV_ALLOC_BLOCK_LENGTH,
                    key_states.dtype, key_states.device
                )
                new_k_cache[...] = k_cache[...]
                new_v_cache[...] = v_cache[...]
                k_cache = new_k_cache
                v_cache = new_v_cache
            k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
            self.key_cache[layer_idx] = k_cache
            self.value_cache[layer_idx] = v_cache

        return self.key_cache[layer_idx], self.value_cache[layer_idx]


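# Worked example of the block allocation above (illustrative numbers): with
# KV_ALLOC_BLOCK_LENGTH = 256, a 100-token prompt pre-allocates capacity for
# 356 tokens, so the first 256 decode steps append in place; the 257th step
# re-allocates once, reserving room for 256 further tokens.

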
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
    to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
                                                           n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


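# For example, with 8 KV heads and n_rep = 4, a (1, 8, 128, 96) tensor becomes
# (1, 32, 128, 96): each KV head is repeated 4 times consecutively along dim 1,
# matching the head layout expected by grouped-query attention.

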
# This function is adapted from
# https://github.com/FasterDecoding/SnapKV/blob/main/snapkv/monkeypatch/snapkv_utils.py
def compress_kv(attn_config, key_states, query_states, value_states, attention_mask,
                num_key_value_groups):
    # compress_kv is only called in the prefill phase, where q_len == kv_len
    invalidInputError(key_states.shape[-2] == query_states.shape[-2], "kv shape mismatch.")
    if not hasattr(attn_config, 'window_size'):
        attn_config.window_size = 32
    if not hasattr(attn_config, 'max_capacity_prompt'):
        attn_config.max_capacity_prompt = 1024
    if not hasattr(attn_config, 'kernel_size'):
        attn_config.kernel_size = 7
    if not hasattr(attn_config, 'pooling'):
        attn_config.pooling = 'maxpool'
    bsz, num_heads, q_len, head_dim = query_states.shape
    if q_len <= attn_config.max_capacity_prompt:
        return key_states, value_states
    else:
        sliding_window_size = getattr(attn_config, "sliding_window", None)
        if sliding_window_size is not None and sliding_window_size <= 2500:
            # for models with a small sliding window, simply keep the window
            return key_states[:, :, -sliding_window_size:, :], \
                value_states[:, :, -sliding_window_size:, :]
        else:
            # score earlier tokens by the attention the last `window_size`
            # queries pay to them; the passed-in `attention_mask` is not used
            # here and is replaced by a causal mask over the observation window
            key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)
            attn_weights = torch.matmul(query_states[..., -attn_config.window_size:, :],
                                        key_states_expand.transpose(2, 3)) / math.sqrt(head_dim)
            mask = torch.full((attn_config.window_size, attn_config.window_size),
                              torch.finfo(attn_weights.dtype).min,
                              device=attn_weights.device)
            mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
            mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
            attention_mask = mask[None, None, :, :]

            attn_weights[:, :, -attn_config.window_size:,
                         -attn_config.window_size:] += attention_mask

            attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                                 dtype=torch.float32).to(query_states.dtype)
            attn_weights_sum = attn_weights[:, :, -attn_config.window_size:,
                                            :-attn_config.window_size].sum(dim=-2)
            # smooth the per-token scores before selecting the top tokens
            if attn_config.pooling == 'avgpool':
                if num_key_value_groups > 1:
                    attn_cache = F.avg_pool2d(attn_weights_sum,
                                              kernel_size=(num_key_value_groups,
                                                           attn_config.kernel_size),
                                              padding=(0, attn_config.kernel_size//2),
                                              stride=(num_key_value_groups, 1))
                else:
                    attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
                                              padding=attn_config.kernel_size//2, stride=1)
            elif attn_config.pooling == 'maxpool':
                if num_key_value_groups > 1:
                    attn_cache = F.max_pool2d(attn_weights_sum,
                                              kernel_size=(num_key_value_groups,
                                                           attn_config.kernel_size),
                                              padding=(0, attn_config.kernel_size//2),
                                              stride=(num_key_value_groups, 1))
                else:
                    attn_cache = F.max_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
                                              padding=attn_config.kernel_size//2, stride=1)
            else:
                invalidInputError(False, 'Pooling method not supported')
            indices = attn_cache.topk(attn_config.max_capacity_prompt - attn_config.window_size,
                                      dim=-1).indices
            indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            # keep the selected earlier tokens plus the recent window verbatim
            k_past_compress = key_states[:, :, :-attn_config.window_size, :]\
                .gather(dim=2, index=indices)
            v_past_compress = value_states[:, :, :-attn_config.window_size, :]\
                .gather(dim=2, index=indices)
            k_cur = key_states[:, :, -attn_config.window_size:, :]
            v_cur = value_states[:, :, -attn_config.window_size:, :]
            key_states = torch.cat([k_past_compress, k_cur], dim=2)
            value_states = torch.cat([v_past_compress, v_cur], dim=2)
            return key_states, value_states


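# Capacity arithmetic for the defaults above (illustrative): with
# max_capacity_prompt = 1024 and window_size = 32, a 4096-token prompt is
# reduced to the top 992 (= 1024 - 32) scored earlier tokens plus the last
# 32 window tokens, i.e. exactly 1024 cached positions per KV head.

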
class DynamicCompressCache(DynamicCache):
    def __init__(self, quant_kv=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # number of tokens actually seen, before compression
        self.real_kv_len = 0
        self.quant_kv = quant_kv
        self.append_kv_func = append_fp8_kv_cache if quant_kv else append_kv_cache

    def update_seen_tokens(self, layer_idx, q_len):
        if layer_idx == 0:
            if hasattr(self, "_seen_tokens"):
                # 4.39 uses `_seen_tokens`
                self._seen_tokens += q_len
            else:
                # 4.37 uses `seen_tokens`
                self.seen_tokens += q_len
            self.real_kv_len += q_len

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        query_states: torch.Tensor,
        attention_mask: torch.Tensor,
        num_key_value_groups: int,
        attn_config: Dict[str, Any],
        enough_kv_room: bool,
        KV_CACHE_ALLOC_BLOCK_LENGTH: int,
        cache_kwargs: Optional[Dict[str, Any]]=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        bsz, num_heads, seq_len, head_dim = key_states.shape

        if layer_idx == 0:
            if hasattr(self, "_seen_tokens"):
                # 4.39 uses `_seen_tokens`
                self._seen_tokens += seq_len
            else:
                # 4.37 uses `seen_tokens`
                self.seen_tokens += seq_len
            self.real_kv_len += seq_len

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            # First token, compress kv cache
            key_states_compress, value_states_compress = compress_kv(
                attn_config=attn_config,
                key_states=key_states,
                query_states=query_states,
                value_states=value_states,
                attention_mask=attention_mask,
                num_key_value_groups=num_key_value_groups)
            # placeholder append so that `layer_idx` is a valid index;
            # overwritten with the real cache buffers below
            self.key_cache.append(key_states_compress)
            self.value_cache.append(value_states_compress)

            if not self.quant_kv:
                k_cache_compressed, v_cache_compressed = init_kv_cache(
                    bsz, num_heads, head_dim,
                    0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
                    key_states.dtype, key_states.device
                )
            else:
                k_cache_compressed, v_cache_compressed = init_fp8_kv_cache(
                    bsz, num_heads, seq_len, head_dim,
                    device=key_states.device,
                )
            k_cache_compressed, v_cache_compressed = self.append_kv_func(
                k_cache_compressed, v_cache_compressed,
                key_states_compress, value_states_compress)
            self.key_cache[layer_idx] = k_cache_compressed
            self.value_cache[layer_idx] = v_cache_compressed

            # return the uncompressed states for this forward pass,
            # repacking them into a contiguous buffer if needed
            if key_states.stride(2) != head_dim:
                if not self.quant_kv:
                    k_cache, v_cache = init_kv_cache(
                        bsz, num_heads, head_dim,
                        0, key_states.size(2),
                        key_states.dtype, key_states.device
                    )
                else:
                    k_cache, v_cache = init_fp8_kv_cache(
                        bsz, num_heads, 0, head_dim,
                        device=key_states.device,
                    )
                k_cache, v_cache = self.append_kv_func(k_cache, v_cache,
                                                       key_states, value_states)
                return k_cache, v_cache
            else:
                return key_states, value_states
        else:
            cache_k = self.key_cache[layer_idx]
            cache_v = self.value_cache[layer_idx]
            if not enough_kv_room and not self.quant_kv:
                # allocate new
                new_c_k, new_c_v = extend_kv_cache(
                    bsz,
                    num_heads,  # Support GQA
                    head_dim,
                    cache_k.size(2),
                    cache_k.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
                    dtype=cache_k.dtype,
                    device=query_states.device)

                new_c_k[:] = cache_k
                new_c_v[:] = cache_v
                cache_k = new_c_k
                cache_v = new_c_v

            key_states, value_states = self.append_kv_func(cache_k,
                                                           cache_v,
                                                           key_states,
                                                           value_states)

            # update past_key_value
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states

            return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states.
        A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        # report the number of tokens actually seen, not the compressed length
        return self.real_kv_len

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                          quantize_kv: Optional[bool] = False) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
        cache = cls(quantize_kv)
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                # NOTE: the overridden `update` above takes extra required
                # arguments, so this path only works for an empty legacy cache
                cache.update(key_states, value_states, layer_idx)
        return cache
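

# Illustrative wiring of DynamicCompressCache into generation (a sketch, not a
# tested recipe; `model` and `inputs` are hypothetical placeholders):
#
#     past_key_values = DynamicCompressCache(quant_kv=False)
#     outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
#
# During prefill, each layer's `update` compresses the prompt KV via
# `compress_kv`; during decode, new tokens are appended uncompressed.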