Phi3 support compresskv (#11733)
* phi3 support compresskv
* fix phi3 mtl error
* fix conflict with quant kv
* fix abnormal output on mtl
* fix style
* use sliding window size to compress kv
* support sliding window
* fix style
* fix style
* temp: partial support for quant kv
* support quant kv together with compress kv, todo: model check
* temp
* fix style
* fix style
* remove prepare
* address comment
* default -> 1.8k
parent d8808cc2e3
commit dd46c141bd

3 changed files with 146 additions and 82 deletions
@@ -155,62 +155,71 @@ def compress_kv(attn_config, key_states, query_states, value_states, attention_m
     if q_len <= attn_config.max_capacity_prompt:
         return key_states, value_states
     else:
-        key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)
-        attn_weights = torch.matmul(query_states[..., -attn_config.window_size:, :],
-                                    key_states_expand.transpose(2, 3)) / math.sqrt(head_dim)
-        mask = torch.full((attn_config.window_size, attn_config.window_size),
-                          torch.finfo(attn_weights.dtype).min,
-                          device=attn_weights.device)
-        mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
-        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-        mask = mask.to(attn_weights.device)
-        attention_mask = mask[None, None, :, :]
-
-        attn_weights[:, :, -attn_config.window_size:, -attn_config.window_size:] += attention_mask
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                             dtype=torch.float32).to(query_states.dtype)
-        attn_weights_sum = attn_weights[:, :, -attn_config.window_size:,
-                                        :-attn_config.window_size].sum(dim=-2)
-        if attn_config.pooling == 'avgpool':
-            if num_key_value_groups > 1:
-                attn_cache = F.avg_pool2d(attn_weights_sum, kernel_size=(num_key_value_groups,
-                                                                         attn_config.kernel_size),
-                                          padding=(0, attn_config.kernel_size//2),
-                                          stride=(num_key_value_groups, 1))
-            else:
-                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
-                                          padding=attn_config.kernel_size//2, stride=1)
-        elif attn_config.pooling == 'maxpool':
-            if num_key_value_groups > 1:
-                attn_cache = F.max_pool2d(attn_weights_sum,
-                                          kernel_size=(num_key_value_groups,
-                                                       attn_config.kernel_size),
-                                          padding=(0, attn_config.kernel_size//2),
-                                          stride=(num_key_value_groups, 1))
-            else:
-                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
-                                          padding=attn_config.kernel_size//2, stride=1)
-        else:
-            invalidInputError(False, 'Pooling method not supported')
-        indices = attn_cache.topk(attn_config.max_capacity_prompt - attn_config.window_size,
-                                  dim=-1).indices
-        indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
-        k_past_compress = key_states[:, :, :-attn_config.window_size, :].gather(dim=2,
-                                                                                index=indices)
-        v_past_compress = value_states[:, :, :-attn_config.window_size, :].gather(dim=2,
-                                                                                  index=indices)
-        k_cur = key_states[:, :, -attn_config.window_size:, :]
-        v_cur = value_states[:, :, -attn_config.window_size:, :]
-        key_states = torch.cat([k_past_compress, k_cur], dim=2)
-        value_states = torch.cat([v_past_compress, v_cur], dim=2)
-        return key_states, value_states
+        sliding_window_size = getattr(attn_config, "sliding_window", None)
+        if sliding_window_size is not None and sliding_window_size <= 2500:
+            return key_states[:, :, -sliding_window_size:, :], \
+                value_states[:, :, -sliding_window_size:, :]
+        else:
+            key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)
+            attn_weights = torch.matmul(query_states[..., -attn_config.window_size:, :],
+                                        key_states_expand.transpose(2, 3)) / math.sqrt(head_dim)
+            mask = torch.full((attn_config.window_size, attn_config.window_size),
+                              torch.finfo(attn_weights.dtype).min,
+                              device=attn_weights.device)
+            mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
+            mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+            mask = mask.to(attn_weights.device)
+            attention_mask = mask[None, None, :, :]
+
+            attn_weights[:, :, -attn_config.window_size:,
+                         -attn_config.window_size:] += attention_mask
+
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                                 dtype=torch.float32).to(query_states.dtype)
+            attn_weights_sum = attn_weights[:, :, -attn_config.window_size:,
+                                            :-attn_config.window_size].sum(dim=-2)
+            if attn_config.pooling == 'avgpool':
+                if num_key_value_groups > 1:
+                    attn_cache = F.avg_pool2d(attn_weights_sum,
+                                              kernel_size=(num_key_value_groups,
+                                                           attn_config.kernel_size),
+                                              padding=(0, attn_config.kernel_size//2),
+                                              stride=(num_key_value_groups, 1))
+                else:
+                    attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
+                                              padding=attn_config.kernel_size//2, stride=1)
+            elif attn_config.pooling == 'maxpool':
+                if num_key_value_groups > 1:
+                    attn_cache = F.max_pool2d(attn_weights_sum,
+                                              kernel_size=(num_key_value_groups,
+                                                           attn_config.kernel_size),
+                                              padding=(0, attn_config.kernel_size//2),
+                                              stride=(num_key_value_groups, 1))
+                else:
+                    attn_cache = F.max_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
+                                              padding=attn_config.kernel_size//2, stride=1)
+            else:
+                invalidInputError(False, 'Pooling method not supported')
+            indices = attn_cache.topk(attn_config.max_capacity_prompt - attn_config.window_size,
+                                      dim=-1).indices
+            indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
+            k_past_compress = key_states[:, :, :-attn_config.window_size, :]\
+                .gather(dim=2, index=indices)
+            v_past_compress = value_states[:, :, :-attn_config.window_size, :]\
+                .gather(dim=2, index=indices)
+            k_cur = key_states[:, :, -attn_config.window_size:, :]
+            v_cur = value_states[:, :, -attn_config.window_size:, :]
+            key_states = torch.cat([k_past_compress, k_cur], dim=2)
+            value_states = torch.cat([v_past_compress, v_cur], dim=2)
+            return key_states, value_states
 
 
 class DynamicCompressCache(DynamicCache):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, quant_kv=False, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.real_kv_len = 0
+        self.quant_kv = quant_kv
+        self.append_kv_func = append_fp8_kv_cache if quant_kv else append_kv_cache
 
     def update_seen_tokens(self, layer_idx, q_len):
         if layer_idx == 0:
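For readers skimming the hunk above: compress_kv keeps the whole prompt cache while it fits max_capacity_prompt; otherwise it scores each older position by how strongly the last window_size queries attend to it, pools the scores, and keeps the top-scoring entries plus the recent window. The new branch short-circuits all of that when the model config defines a sliding_window at or below 2500, since under sliding-window attention those truncated positions would be masked out for future queries anyway. The sketch below is illustrative only: toy_compress and its defaults are invented for this note, and GQA handling, pooling, causal masking, fp8 quantization and the new sliding-window fast path are all omitted.

import torch

def toy_compress(key, value, query, window_size=32, max_capacity=512):
    # key/value/query: [batch, heads, seq_len, head_dim]; prefill only
    q_len = query.size(2)
    if q_len <= max_capacity:
        return key, value
    # score past positions by how much the last `window_size` queries attend to them
    scores = torch.matmul(query[:, :, -window_size:], key.transpose(2, 3))
    scores = scores.softmax(dim=-1)[..., :-window_size].sum(dim=-2)
    # keep the best (max_capacity - window_size) past entries plus the recent window
    idx = scores.topk(max_capacity - window_size, dim=-1).indices
    idx = idx.unsqueeze(-1).expand(-1, -1, -1, key.size(-1))
    k_keep = torch.cat([key[:, :, :-window_size].gather(2, idx),
                        key[:, :, -window_size:]], dim=2)
    v_keep = torch.cat([value[:, :, :-window_size].gather(2, idx),
                        value[:, :, -window_size:]], dim=2)
    return k_keep, v_keep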
@@ -260,49 +269,62 @@ class DynamicCompressCache(DynamicCache):
             self.key_cache.append(key_states_compress)
             self.value_cache.append(value_states_compress)
 
-            k_cache_compressed, v_cache_compressed = init_kv_cache(
-                bsz, num_heads, head_dim,
-                0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                key_states.dtype, key_states.device
-            )
-            k_cache_compressed, v_cache_compressed = append_kv_cache(
+            if not self.quant_kv:
+                k_cache_compressed, v_cache_compressed = init_kv_cache(
+                    bsz, num_heads, head_dim,
+                    0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                    key_states.dtype, key_states.device
+                )
+            else:
+                k_cache_compressed, v_cache_compressed = init_fp8_kv_cache(
+                    bsz, num_heads, seq_len, head_dim,
+                    device=key_states.device,
+                )
+            k_cache_compressed, v_cache_compressed = self.append_kv_func(
                 k_cache_compressed, v_cache_compressed,
                 key_states_compress, value_states_compress)
             self.key_cache[layer_idx] = k_cache_compressed
             self.value_cache[layer_idx] = v_cache_compressed
 
             if key_states.stride(2) != head_dim:
-                k_cache, v_cache = init_kv_cache(
-                    bsz, num_heads, head_dim,
-                    0, key_states.size(2),
-                    key_states.dtype, key_states.device
-                )
-                k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
+                if not self.quant_kv:
+                    k_cache, v_cache = init_kv_cache(
+                        bsz, num_heads, head_dim,
+                        0, key_states.size(2),
+                        key_states.dtype, key_states.device
+                    )
+                else:
+                    k_cache, v_cache = init_fp8_kv_cache(
+                        bsz, num_heads, 0, head_dim, key_states.device
+                    )
+                k_cache, v_cache = self.append_kv_func(k_cache, v_cache,
+                                                       key_states, value_states)
                 return k_cache, v_cache
             else:
                 return key_states, value_states
         else:
             cache_k = self.key_cache[layer_idx]
             cache_v = self.value_cache[layer_idx]
-            if not enough_kv_room:
+            if not enough_kv_room and not self.quant_kv:
                 # allocate new
-                new_c_k, new_c_v = extend_kv_cache(bsz,
-                                                   num_heads,  # Support GQA
-                                                   head_dim,
-                                                   cache_k.size(2),
-                                                   cache_k.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                   dtype=cache_k.dtype,
-                                                   device=query_states.device)
+                new_c_k, new_c_v = extend_kv_cache(
+                    bsz,
+                    num_heads,  # Support GQA
+                    head_dim,
+                    cache_k.size(2),
+                    cache_k.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                    dtype=cache_k.dtype,
+                    device=query_states.device)
 
                 new_c_k[:] = cache_k
                 new_c_v[:] = cache_v
                 cache_k = new_c_k
                 cache_v = new_c_v
 
-            key_states, value_states = append_kv_cache(cache_k,
-                                                        cache_v,
-                                                        key_states,
-                                                        value_states)
+            key_states, value_states = self.append_kv_func(cache_k,
+                                                           cache_v,
+                                                           key_states,
+                                                           value_states)
 
             # update past_key_value
             self.key_cache[layer_idx] = key_states
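What this hunk does, in one line: every place the cache used to call init_kv_cache/append_kv_cache now branches on self.quant_kv and goes through init_fp8_kv_cache/self.append_kv_func instead, so the same DynamicCompressCache serves both fp16 and quantized KV. A hedged illustration of the flag (constructor arguments beyond the flag omitted; assumes an ipex-llm build that includes this patch):

from ipex_llm.transformers.kv import DynamicCompressCache

cache_fp16 = DynamicCompressCache()               # append_kv_func is append_kv_cache
cache_fp8 = DynamicCompressCache(quant_kv=True)   # append_kv_func is append_fp8_kv_cache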
@@ -316,3 +338,14 @@ class DynamicCompressCache(DynamicCache):
         if len(self.key_cache) <= layer_idx:
             return 0
         return self.real_kv_len
+
+    @classmethod
+    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+                          quantize_kv: Optional[bool] = False) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
+        cache = cls(quantize_kv)
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
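The new classmethod mirrors DynamicCache.from_legacy_cache but threads a quantize_kv flag through to the constructor above, so the caller decides the KV precision once. A hedged usage sketch (legacy_cache is a stand-in for a transformers-style tuple cache, or None at the start of generation):

from ipex_llm.transformers.kv import DynamicCompressCache

legacy_cache = None  # or a tuple of per-layer (key, value) tensors
cache = DynamicCompressCache.from_legacy_cache(legacy_cache, quantize_kv=True)
# quantize_kv=True routes appends through append_fp8_kv_cache;
# quantize_kv=False keeps the plain append_kv_cache path.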
@@ -31,6 +31,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import math
 import torch
 import warnings
@@ -40,11 +41,13 @@ from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
+from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
+from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache
 
 from typing import Optional, Tuple, List
 from transformers.models.phi.modeling_phi import repeat_kv
 from transformers.cache_utils import Cache
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -94,6 +97,9 @@ def attention_forward(
 
     bsz, q_len, _ = hidden_states.size()
 
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     qkv = self.qkv_proj(hidden_states)
     qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
     qkv = qkv.transpose(1, 2)
@@ -127,12 +133,26 @@ def attention_forward(
                                                         cos, sin, position_ids)
 
     if past_key_value is not None:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, None)
+        # [CompressKV]
+        if use_compresskv:
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, None)
 
     if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
+        # [CompressKV]
+        if use_compresskv:
+            # print(attention_mask.shape)
+            context_len = key_states.size(2)
+            attention_mask = attention_mask[:, :, :, -context_len:]
         import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if isinstance(past_key_value,
+                      DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
         else:
@@ -148,7 +168,8 @@ def attention_forward(
     #         attn_output = xe_addons.sdp_causal(query_states, key_states,
     #                                            value_states, attention_mask)
     else:
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if isinstance(past_key_value,
+                      DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
         # repeat k/v heads if n_kv_heads < n_heads
@@ -235,10 +256,20 @@ def phi3_model_forward_wrapper(origin_model_forward):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         input = input_ids if input_ids is not None else inputs_embeds
         use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
+        use_compress_kv = should_use_compresskv(input, input.shape[-1])
         if use_cache:
-            if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+            if use_compress_kv and not isinstance(past_key_values,
+                                                  DynamicCompressCache):
+                past_key_values = DynamicCompressCache.\
+                    from_legacy_cache(past_key_values,
+                                      quantize_kv=use_quantize_kv)
+            if use_quantize_kv and not isinstance(past_key_values,
+                                                  (DynamicFp8Cache, DynamicCompressCache)):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-            if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+            if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                              (DynamicNormalCache,
+                                                                               DynamicCompressCache
+                                                                               )):
                 past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
         return origin_model_forward(
             self=self,
@@ -490,7 +490,7 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
     if use_compress_kv is None:
         return (
             get_xpu_device_type(x) == "mtl"
-            and prompt_len >= 2500
+            and prompt_len >= 1800
             and prompt_len <= 4500
         )
     else:
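With the floor lowered from 2500 to 1800, the default branch of should_use_compresskv (taken when use_compress_kv has not been set explicitly) enables compress-kv only on MTL XPUs and only for prompts of roughly 1.8k to 4.5k tokens, matching the "default -> 1.8k" entry in the commit message. A hedged restatement of just that branch (the function name is invented for this note, and device_type stands in for the get_xpu_device_type(x) helper used in the hunk):

def compresskv_by_default(device_type: str, prompt_len: int) -> bool:
    # mirrors the `use_compress_kv is None` branch after this patch
    return device_type == "mtl" and 1800 <= prompt_len <= 4500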