diff --git a/python/llm/src/ipex_llm/transformers/kv.py b/python/llm/src/ipex_llm/transformers/kv.py
index 0e3803f5..100da837 100644
--- a/python/llm/src/ipex_llm/transformers/kv.py
+++ b/python/llm/src/ipex_llm/transformers/kv.py
@@ -218,8 +218,6 @@ class DynamicCompressCache(DynamicCache):
     def __init__(self, quant_kv=False, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.real_kv_len = 0
-        self.quant_kv = quant_kv
-        self.append_kv_func = append_fp8_kv_cache if quant_kv else append_kv_cache

     def update_seen_tokens(self, layer_idx, q_len):
         if layer_idx == 0:
@@ -266,46 +264,33 @@ class DynamicCompressCache(DynamicCache):
                 value_states=value_states,
                 attention_mask=attention_mask,
                 num_key_value_groups=num_key_value_groups)
-            self.key_cache.append(key_states_compress)
-            self.value_cache.append(value_states_compress)

-            if not self.quant_kv:
-                k_cache_compressed, v_cache_compressed = init_kv_cache(
-                    bsz, num_heads, head_dim,
-                    0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                    key_states.dtype, key_states.device
-                )
-            else:
-                k_cache_compressed, v_cache_compressed = init_fp8_kv_cache(
-                    bsz, num_heads, seq_len, head_dim,
-                    device=key_states.device,
-                )
-            k_cache_compressed, v_cache_compressed = self.append_kv_func(
+            k_cache_compressed, v_cache_compressed = init_kv_cache(
+                bsz, num_heads, head_dim,
+                0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                key_states.dtype, key_states.device
+            )
+            k_cache_compressed, v_cache_compressed = append_kv_cache(
                 k_cache_compressed, v_cache_compressed,
                 key_states_compress, value_states_compress)
-            self.key_cache[layer_idx] = k_cache_compressed
-            self.value_cache[layer_idx] = v_cache_compressed
+            self.key_cache.append(k_cache_compressed)
+            self.value_cache.append(v_cache_compressed)

             if key_states.stride(2) != head_dim:
-                if not self.quant_kv:
-                    k_cache, v_cache = init_kv_cache(
-                        bsz, num_heads, head_dim,
-                        0, key_states.size(2),
-                        key_states.dtype, key_states.device
-                    )
-                else:
-                    k_cache, v_cache = init_fp8_kv_cache(
-                        bsz, num_heads, 0, head_dim, key_states.device
-                    )
-                k_cache, v_cache = self.append_kv_func(k_cache, v_cache,
-                                                       key_states, value_states)
+                k_cache, v_cache = init_kv_cache(
+                    bsz, num_heads, head_dim,
+                    0, key_states.size(2),
+                    key_states.dtype, key_states.device
+                )
+                k_cache, v_cache = append_kv_cache(k_cache, v_cache,
+                                                   key_states, value_states)
                 return k_cache, v_cache
             else:
                 return key_states, value_states
         else:
             cache_k = self.key_cache[layer_idx]
             cache_v = self.value_cache[layer_idx]
-            if not enough_kv_room and not self.quant_kv:
+            if not enough_kv_room:
                 # allocate new
                 new_c_k, new_c_v = extend_kv_cache(
                     bsz,
@@ -321,10 +306,10 @@ class DynamicCompressCache(DynamicCache):
                 cache_k = new_c_k
                 cache_v = new_c_v

-            key_states, value_states = self.append_kv_func(cache_k,
-                                                           cache_v,
-                                                           key_states,
-                                                           value_states)
+            key_states, value_states = append_kv_cache(cache_k,
+                                                       cache_v,
+                                                       key_states,
+                                                       value_states)

             # update past_key_value
             self.key_cache[layer_idx] = key_states
@@ -339,13 +324,74 @@ class DynamicCompressCache(DynamicCache):
             return 0
         return self.real_kv_len

-    @classmethod
-    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-                          quantize_kv: Optional[bool] = False) -> "DynamicCache":
-        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
-        cache = cls(quantize_kv)
-        if past_key_values is not None:
-            for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx]
-                cache.update(key_states,
-                             value_states, layer_idx)
-        return cache
+
+class DynamicCompressFp8Cache(DynamicCompressCache, DynamicFp8Cache):
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        query_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        num_key_value_groups: int,
+        attn_config: Dict[str, Any],
+        enough_kv_room: bool,
+        KV_CACHE_ALLOC_BLOCK_LENGTH: int,
+        cache_kwargs: Optional[Dict[str, Any]]=None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        bsz, num_heads, seq_len, head_dim = key_states.shape
+
+        if layer_idx == 0:
+            if hasattr(self, "_seen_tokens"):
+                # 4.39 uses `_seen_tokens`
+                self._seen_tokens += seq_len
+            else:
+                # 4.37 uses `seen_tokens`
+                self.seen_tokens += seq_len
+            self.real_kv_len += seq_len
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            # First token, compress kv cache
+            key_states_compress, value_states_compress = compress_kv(
+                attn_config=attn_config,
+                key_states=key_states,
+                query_states=query_states,
+                value_states=value_states,
+                attention_mask=attention_mask,
+                num_key_value_groups=num_key_value_groups)
+
+            k_cache_compressed, v_cache_compressed = init_fp8_kv_cache(
+                bsz, num_heads, seq_len, head_dim,
+                device=key_states.device,
+            )
+
+            k_cache_compressed, v_cache_compressed = append_fp8_kv_cache(
+                k_cache_compressed, v_cache_compressed,
+                key_states_compress, value_states_compress)
+            self.key_cache.append(k_cache_compressed)
+            self.value_cache.append(v_cache_compressed)
+
+            if key_states.stride(2) != head_dim:
+                k_cache, v_cache = init_fp8_kv_cache(
+                    bsz, num_heads, 0, head_dim, key_states.device
+                )
+                k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache,
+                                                       key_states, value_states)
+                return k_cache, v_cache
+            else:
+                return key_states, value_states
+        else:
+            cache_k = self.key_cache[layer_idx]
+            cache_v = self.value_cache[layer_idx]
+
+            key_states, value_states = append_fp8_kv_cache(cache_k,
+                                                           cache_v,
+                                                           key_states,
+                                                           value_states)
+
+            # update past_key_value
+            self.key_cache[layer_idx] = key_states
+            self.value_cache[layer_idx] = value_states
+
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
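The new DynamicCompressFp8Cache replaces the old quant_kv flag by mixing the compress-KV update flow of DynamicCompressCache with FP8 storage from DynamicFp8Cache. A minimal sketch of the same idea with simplified stand-in classes (the 64-token window, the int8-with-fixed-scale storage and the class names below are illustrative only, not the ipex_llm implementation):

from typing import List, Tuple
import torch


class _BaseCache:
    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def _store(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return k, v  # plain fp16/fp32 storage


class _Fp8Cache(_BaseCache):
    SCALE = 16.0  # fixed toy scale; the real helpers track proper quantization state

    def _store(self, k, v):
        # stand-in for append_fp8_kv_cache: keep the cache in a low-bit format
        q = lambda t: (t * self.SCALE).round().clamp(-127, 127).to(torch.int8)
        return q(k), q(v)


class _CompressCache(_BaseCache):
    def update(self, k, v, layer_idx):
        if len(self.key_cache) <= layer_idx:
            # prefill: compress before caching (stand-in for compress_kv)
            k, v = self._store(k[..., -64:, :], v[..., -64:, :])
            self.key_cache.append(k)
            self.value_cache.append(v)
        else:
            # decode: append new tokens to the already-stored cache
            k, v = self._store(k, v)
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], k], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], v], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


class _CompressFp8Cache(_CompressCache, _Fp8Cache):
    # The MRO resolves self._store to _Fp8Cache._store, so the compress path stores
    # low-bit tensors; DynamicCompressFp8Cache follows the same pattern by swapping
    # only the init/append helpers while reusing the compress-KV control flow.
    pass

This is also why the model code below can pick the cache class once up front instead of branching on a quant_kv attribute inside every update.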
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index 7a214e57..b1394302 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -17,6 +17,7 @@
 # https://huggingface.co/THUDM/chatglm2-6b/blob/8eb45c842594b8473f291d0f94e7bbe86ffc67d8/modeling_chatglm.py
 #

+import os
 import math
 import torch
 from typing import Optional, Tuple
@@ -27,7 +28,9 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, u
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
     use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.kv import DynamicCompressCache
+from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -90,9 +93,12 @@ def chatglm2_model_forward(
         use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 input_ids)
-        if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
-                                                                      DynamicCompressCache):
-            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        if use_compress_kv and not isinstance(past_key_values,
+                                              DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)

     if full_attention_mask is None:
         if (attention_mask is not None and not attention_mask.all()) or (
@@ -279,15 +285,9 @@ def chatglm2_attention_forward(

     # IPEX-LLM OPT: kv cache and quantize kv
     use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
-    if use_quantize_kv or (not use_compresskv):
-        key_states, value_states = update_past_key_value(
-            past_key_value, key_states, value_states,
-            kv_seq_len, use_quantize_kv, hidden_states.device
-        )
-        # past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
-        past_key_value = (key_states.permute(2, 0, 1, 3),
-                          value_states.permute(2, 0, 1, 3)) if use_cache else None
-    else:
+
+    # [CompressKV]
+    if use_compresskv:
         from transformers.configuration_utils import PretrainedConfig
         self.config = self.config if hasattr(self, "config") else PretrainedConfig()
         enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
@@ -296,8 +296,16 @@ def chatglm2_attention_forward(
         key_states, value_states = past_key_value.update(
             key_states, value_states, self.layer_number - 1,
             query_states, attention_mask, n_head // n_kv_head,
-            self.config, enough_kv_room, 256
+            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
         )
+    else:
+        key_states, value_states = update_past_key_value(
+            past_key_value, key_states, value_states,
+            kv_seq_len, use_quantize_kv, hidden_states.device
+        )
+        # past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
+        past_key_value = (key_states.permute(2, 0, 1, 3),
+                          value_states.permute(2, 0, 1, 3)) if use_cache else None

     # IPEX-LLM OPT: sdp
     attn_weights = None
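chatglm2 (and chatglm4 below) now read the allocation block length from the KV_CACHE_ALLOC_BLOCK_LENGTH environment variable instead of hard-coding 256 at the update() call site. A small sketch of the pattern; the preallocated_kv_len helper name is illustrative only, but its arithmetic mirrors the init_kv_cache(..., 0, current_len + KV_CACHE_ALLOC_BLOCK_LENGTH, ...) call in kv.py:

import os

# Same fallback as the module-level constants added above.
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


def preallocated_kv_len(current_len: int, block: int = KV_CACHE_ALLOC_BLOCK_LENGTH) -> int:
    # Reserve one extra block of headroom so that subsequent decode steps
    # do not need to re-allocate the cache immediately.
    return current_len + block


print(preallocated_kv_len(1000))  # 1256 with the default block size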
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index 0103b243..2daffedb 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -17,6 +17,7 @@
 # https://huggingface.co/THUDM/chatglm2-6b-32k/blob/main/configuration_chatglm.py
 #

+import os
 import torch
 from typing import Optional, Tuple, Union
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
@@ -25,10 +26,12 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
     get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.chatglm2 import repeat_kv
-from ipex_llm.transformers.kv import DynamicCompressCache
+from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 import math

+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
+

 def chatglm4_model_forward(
     self,
@@ -54,9 +57,12 @@
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 inputs)
-        if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
-                                                                      DynamicCompressCache):
-            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        if use_compress_kv and not isinstance(past_key_values,
+                                              DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)

     if inputs_embeds is None:
         batch_size, seq_length = input_ids.shape
@@ -201,7 +207,19 @@ def chatglm4_attention_forward(
     # IPEX-LLM OPT: kv cache and quantize kv
     use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)

-    if use_quantize_kv or (not use_compresskv):
+    # [CompressKV]
+    if use_compresskv:
+        from transformers.configuration_utils import PretrainedConfig
+        self.config = self.config if hasattr(self, "config") else PretrainedConfig()
+        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                      self.layer_number - 1,
+                                                      q_len)
+        key_states, value_states = past_key_value.update(
+            key_states, value_states, self.layer_number - 1,
+            query_states, attention_mask, n_head // n_kv_head,
+            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
+        )
+    else:
         key_states, value_states = update_past_key_value(
             past_key_value, key_states, value_states,
             kv_seq_len, use_quantize_kv, hidden_states.device
@@ -214,30 +232,19 @@ def chatglm4_attention_forward(
             past_key_value = (key_states, value_states)
         else:
             past_key_value = None
-    else:
-        from transformers.configuration_utils import PretrainedConfig
-        self.config = self.config if hasattr(self, "config") else PretrainedConfig()
-        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
-                                                      self.layer_number - 1,
-                                                      q_len)
-        key_states, value_states = past_key_value.update(
-            key_states, value_states, self.layer_number - 1,
-            query_states, attention_mask, n_head // n_kv_head,
-            self.config, enough_kv_room, 256
-        )

     # IPEX-LLM OPT: sdp
     attn_weights = None
     if use_sdp(q_len, kv_seq_len, head_dim, query_states):
         import xe_addons
+        if use_compresskv:
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
         else:
             attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
     elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training):
         import xe_addons
-        if use_compresskv:
-            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
                                                    attention_mask)
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 6c2680bf..5e633da7 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -120,19 +120,25 @@ def llama_model_forward_4_36(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache, \
+        DynamicCompressFp8Cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
-                                 self.config.num_attention_heads//self.config.num_key_value_heads):
-            if not isinstance(past_key_values, DynamicFp8Cache):
-                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[1]):
-            # if use quantize kv, compress kv will be ignored now
+        use_quantize = use_quantize_kv_cache(
+            self.layers[0].mlp.up_proj, input,
+            self.config.num_attention_heads//self.config.num_key_value_heads)
+        if should_use_compresskv(input, input.shape[1]):
             if not isinstance(past_key_values, DynamicCompressCache):
-                past_key_values = DynamicCompressCache.from_legacy_cache(
-                    past_key_values)
+                if use_quantize:
+                    past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
+                        past_key_values)
+                else:
+                    past_key_values = DynamicCompressCache.from_legacy_cache(
+                        past_key_values)
+        elif use_quantize:
+            if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
+                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return llama_model_forward_4_36_internal(
         self=self,
         input_ids=input_ids,
@@ -160,19 +166,25 @@ def llama_model_forward_4_38(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache, \
+        DynamicCompressFp8Cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
-                                 self.config.num_attention_heads//self.config.num_key_value_heads):
-            if not isinstance(past_key_values, DynamicFp8Cache):
-                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[1]):
-            # if use quantize kv, compress kv will be ignored now
+        use_quantize = use_quantize_kv_cache(
+            self.layers[0].mlp.up_proj, input,
+            self.config.num_attention_heads//self.config.num_key_value_heads)
+        if should_use_compresskv(input, input.shape[1]):
             if not isinstance(past_key_values, DynamicCompressCache):
-                past_key_values = DynamicCompressCache.from_legacy_cache(
-                    past_key_values)
+                if use_quantize:
+                    past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
+                        past_key_values)
+                else:
+                    past_key_values = DynamicCompressCache.from_legacy_cache(
+                        past_key_values)
+        elif use_quantize:
+            if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
+                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return llama_model_forward_4_38_internal(
         self=self,
         input_ids=input_ids,
@@ -201,19 +213,25 @@ def llama_model_forward_4_41(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache, \
+        DynamicCompressFp8Cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
-                                 self.config.num_attention_heads//self.config.num_key_value_heads):
-            if not isinstance(past_key_values, DynamicFp8Cache):
-                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[1]):
-            # if use quantize kv, compress kv will be ignored now
+        use_quantize = use_quantize_kv_cache(
+            self.layers[0].mlp.up_proj, input,
+            self.config.num_attention_heads//self.config.num_key_value_heads)
+        if should_use_compresskv(input, input.shape[1]):
             if not isinstance(past_key_values, DynamicCompressCache):
-                past_key_values = DynamicCompressCache.from_legacy_cache(
-                    past_key_values)
+                if use_quantize:
+                    past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
+                        past_key_values)
+                else:
+                    past_key_values = DynamicCompressCache.from_legacy_cache(
+                        past_key_values)
+        elif use_quantize:
+            if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
+                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return llama_model_forward_4_41_internal(
         self=self,
         input_ids=input_ids,
@@ -1086,6 +1104,7 @@ def llama_attention_forward_4_41_quantized(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1102,6 +1121,9 @@
                                                      enough_kv_room,
                                                      bsz * q_len,
                                                      llama_decoding_fast_path_qtype_check) and no_tp
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         tmp_cache_k, tmp_cache_v = init_kv_cache(
@@ -1177,8 +1199,15 @@
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         if use_cache:
             cache_kwargs = None
-            key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs)
+            # [CompressKV]
+            if use_compresskv:
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx,
+                    query_states, attention_mask, self.num_key_value_groups,
+                    self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+            else:
+                key_states, value_states = past_key_value.update(key_states, value_states,
+                                                                 self.layer_idx, cache_kwargs)
         if use_cache and use_sdp_causal(q_len, kv_seq_len, self.head_dim,
                                         query_states, self.training):
             import xe_addons
@@ -1227,8 +1256,15 @@
                 attn_output = torch.matmul(attn_weights, repeated_value_states)
     else:
         cache_kwargs = None  # Specific to RoPE models
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, cache_kwargs)
+        # [CompressKV]
+        if use_compresskv:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, cache_kwargs)
         kv_seq_len = key_states.shape[-2]
         if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
@@ -1275,6 +1311,11 @@
                     new_attn_mask = attention_mask[:, :, :, 0:kv_seq_len]
                 else:
                     new_attn_mask = attention_mask
+
+                # [CompressKV]
+                if use_compresskv:
+                    new_attn_mask = get_compresskv_attn_mask(key_states,
+                                                             new_attn_mask)
                 attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                                 new_attn_mask)
                 attn_weights = None
@@ -1652,6 +1693,7 @@ def llama_attention_forward_4_38_quantized(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1668,6 +1710,10 @@
                                                      enough_kv_room,
                                                      bsz * q_len,
                                                      llama_decoding_fast_path_qtype_check) and no_tp
+
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         tmp_cache_k, tmp_cache_v = init_kv_cache(
@@ -1743,8 +1789,16 @@
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         if use_cache:
             cache_kwargs = None
-            key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs)
+            # [CompressKV]
+            if use_compresskv:
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx,
+                    query_states, attention_mask, self.num_key_value_groups,
+                    self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+            else:
+                key_states, value_states = past_key_value.update(key_states, value_states,
+                                                                 self.layer_idx, cache_kwargs)
+
         if use_cache and use_sdp_causal(q_len, kv_seq_len, self.head_dim,
                                         query_states, self.training):
             import xe_addons
@@ -1793,8 +1847,15 @@
                 attn_output = torch.matmul(attn_weights, repeated_value_states)
     else:
         cache_kwargs = None  # Specific to RoPE models
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, cache_kwargs)
+        # [CompressKV]
+        if use_compresskv:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, cache_kwargs)
         kv_seq_len = key_states.shape[-2]
         if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
@@ -1841,6 +1902,11 @@
                     new_attn_mask = attention_mask[:, :, kv_seq_len-q_len:kv_seq_len, 0:kv_seq_len]
                 else:
                     new_attn_mask = attention_mask
+
+                # [CompressKV]
+                if use_compresskv:
+                    new_attn_mask = get_compresskv_attn_mask(key_states,
+                                                             new_attn_mask)
                 attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                                 new_attn_mask)
                 attn_weights = None
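All three llama model forwards now share the same selection order: CompressKV takes priority, and when quantized KV is also enabled the FP8 variant of the compress cache is chosen; plain FP8 conversion only happens when compression is off. A condensed view of that logic as a standalone helper (the pick_cache name is illustrative, not an ipex_llm API):

from ipex_llm.transformers.kv import (DynamicFp8Cache, DynamicCompressCache,
                                      DynamicCompressFp8Cache)


def pick_cache(past_key_values, use_compress: bool, use_quantize: bool):
    # CompressKV wins; the FP8 flavour is used when quantized KV is also enabled.
    if use_compress:
        if isinstance(past_key_values, DynamicCompressCache):
            return past_key_values  # already converted on an earlier step
        cls = DynamicCompressFp8Cache if use_quantize else DynamicCompressCache
        return cls.from_legacy_cache(past_key_values)
    # Quantize-only path; a compress cache passed back in is left untouched.
    if use_quantize and not isinstance(past_key_values, (DynamicFp8Cache,
                                                         DynamicCompressCache)):
        return DynamicFp8Cache.from_legacy_cache(past_key_values)
    return past_key_values

The chatglm2/chatglm4 forwards above and the minicpm, phi3 and qwen2 wrappers below apply the same precedence.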
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py
index 926e6247..afbcde6c 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpm.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -47,7 +47,8 @@ from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, get_compres
 from ipex_llm.transformers.models.utils import should_use_compresskv, should_use_fuse_rope
 from ipex_llm.transformers.models.llama import repeat_kv
 from ipex_llm.transformers.models.common import merge_qkv_base
-from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache
+from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, \
+    DynamicCompressCache, DynamicCompressFp8Cache
 from transformers.cache_utils import Cache


@@ -79,6 +80,10 @@ def minicpm_attention_forward(
                                                  self.num_key_value_heads,
                                                  self.num_key_value_heads], dim=1)

+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+    use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)
+
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
@@ -94,7 +99,7 @@ def minicpm_attention_forward(
     )

     if past_key_value is not None:
-        if isinstance(past_key_value, DynamicCompressCache):
+        if use_compresskv:
             enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, q_len)
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_idx,
@@ -107,10 +112,11 @@
     attn_weights = None
     if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import xe_addons
-        if isinstance(past_key_value, DynamicCompressCache):
+        # [CompressKV]
+        if use_compresskv:
             attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
-        elif isinstance(past_key_value, DynamicFp8Cache):
+
+        if use_quantizekv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
         else:
@@ -118,14 +124,14 @@ def minicpm_attention_forward(
                                         attention_mask)
     elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
         import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if use_quantizekv:
             attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
                                                    attention_mask)
         else:
             attn_output = xe_addons.sdp_causal(query_states, key_states, value_states,
                                                attention_mask)
     else:
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if use_quantizekv:
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
             key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -180,11 +186,15 @@ def minicpm_model_forward_wrapper(origin_forward):
        use_cache = use_cache if use_cache is not None else self.config.use_cache

         if use_cache:
-            if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+            if use_compress_kv and not isinstance(past_key_values,
+                                                  DynamicCompressCache):
+                if use_quantize_kv:
+                    past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+                else:
+                    past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+            elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
+                                                                      DynamicCompressCache)):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-            elif not use_quantize_kv and use_compress_kv and not isinstance(past_key_values,
-                                                                            DynamicCompressCache):
-                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
             elif (not use_quantize_kv and not use_compress_kv and
                     not isinstance(past_key_values, (DynamicNormalCache, DynamicCompressCache))):
                 past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
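The attention code can test use_compresskv and use_quantizekv independently because DynamicCompressFp8Cache inherits from both base classes, so a compressed-FP8 cache satisfies both isinstance checks. A quick check (assumes ipex_llm is importable; constructing the caches with no arguments relies on their default parameters):

from ipex_llm.transformers.kv import (DynamicFp8Cache, DynamicCompressCache,
                                      DynamicCompressFp8Cache)

past_key_value = DynamicCompressFp8Cache()

use_compresskv = isinstance(past_key_value, DynamicCompressCache)  # True -> compress update path
use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)       # True -> sdp_fp8 / restore_fp8_kv_cache
print(use_compresskv, use_quantizekv)

This is also why the dispatch no longer needs the past_key_value.quant_kv attribute that the phi3 path below used to consult.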
diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index 04ed59af..6c8e6dfa 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -40,10 +40,11 @@ from torch import nn
 from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache
+from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, \
+    DynamicCompressCache, DynamicCompressFp8Cache
 from typing import Optional, Tuple, List
 from transformers.models.phi.modeling_phi import repeat_kv

@@ -100,6 +101,7 @@ def attention_forward(
     # [CompressKV]
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+    use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)

     qkv = self.qkv_proj(hidden_states)
     qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
@@ -150,12 +152,9 @@
     if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         # [CompressKV]
         if use_compresskv:
-            # print(attention_mask.shape)
-            context_len = key_states.size(2)
-            attention_mask = attention_mask[:, :, :, -context_len:]
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         import xe_addons
-        if isinstance(past_key_value,
-                      DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
+        if use_quantizekv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
         else:
@@ -171,8 +170,7 @@
             # attn_output = xe_addons.sdp_causal(query_states, key_states,
             #                                    value_states, attention_mask)
     else:
-        if isinstance(past_key_value,
-                      DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
+        if use_quantizekv:
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
             # repeat k/v heads if n_kv_heads < n_heads
@@ -262,11 +260,12 @@ def phi3_model_forward_wrapper(origin_model_forward):
         if use_cache:
             if use_compress_kv and not isinstance(past_key_values,
                                                   DynamicCompressCache):
-                past_key_values = DynamicCompressCache.\
-                    from_legacy_cache(past_key_values,
-                                      quantize_kv=use_quantize_kv)
-            if use_quantize_kv and not isinstance(past_key_values,
-                                                  (DynamicFp8Cache, DynamicCompressCache)):
+                if use_quantize_kv:
+                    past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+                else:
+                    past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+            if use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
+                                                                    DynamicCompressCache)):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
             if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
                                                                               (DynamicNormalCache,
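The inline mask trimming that phi3 used to do by hand (take context_len = key_states.size(2) and slice the last context_len columns) is now centralized in get_compresskv_attn_mask. A sketch of the equivalent operation, with behavior assumed from the code removed above (the real helper lives in ipex_llm.transformers.models.utils and its exact handling may differ):

from typing import Optional
import torch


def trim_attn_mask_to_kv(key_states: torch.Tensor,
                         attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Keep only the last kv_len columns so the mask lines up with the compressed cache.
    if attention_mask is None:
        return None
    kv_len = key_states.size(2)
    return attention_mask[:, :, :, -kv_len:]


# Example shapes: 16 compressed KV positions out of an original 128-token mask.
mask = torch.zeros(1, 1, 1, 128)
keys = torch.randn(1, 8, 16, 64)
print(trim_attn_mask_to_kv(keys, mask).shape)  # torch.Size([1, 1, 1, 16])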
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 8368aa58..c01488a6 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -51,7 +51,8 @@ from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
     should_use_compresskv, is_enough_kv_cache_room_4_36, get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
-from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, DynamicCompressCache
+from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, \
+    DynamicCompressCache, DynamicCompressFp8Cache
 from ipex_llm.utils.common import invalidInputError
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
@@ -122,11 +123,14 @@ def qwen2_model_forward(
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])

         if use_cache:
-            if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+            if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
+                if use_quantize_kv:
+                    past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+                else:
+                    past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+            elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
+                                                                      DynamicCompressCache)):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-            elif not use_quantize_kv and use_compress_kv and not isinstance(past_key_values,
-                                                                            DynamicCompressCache):
-                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
             if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
                                                                               (DynamicNormalCache,
                                                                                DynamicCompressCache)):
@@ -312,10 +316,20 @@ def qwen2_model_forward_4_42(
                        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds,
                                                  self.config.num_attention_heads//self.config.num_key_value_heads)
                        )
+    use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1])
+
     if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+        if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
+                                                                  DynamicCompressCache)):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+        elif not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                            (DynamicNormalCache,
+                                                                             DynamicCompressCache)):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
     # ipex-llm changes end
@@ -522,6 +536,7 @@ def qwen2_attention_forward(
     # [CompressKV]
     from ipex_llm.transformers.kv import DynamicCompressCache
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+    use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)

     if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
         qkv = self.qkv_proj(hidden_states)
@@ -592,7 +607,7 @@ def qwen2_attention_forward(
         import xe_addons
         if use_compresskv:
             attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if use_quantizekv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
         else:
@@ -600,14 +615,14 @@ def qwen2_attention_forward(
                                         attention_mask)
     elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
         import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if use_quantizekv:
             attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
                                                    attention_mask)
         else:
             attn_output = xe_addons.sdp_causal(query_states, key_states, value_states,
                                                attention_mask)
     else:
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if use_quantizekv:
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
             # repeat k/v heads if n_kv_heads < n_heads
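On the fallback (non-SDP) branches above, a quantized cache is first dequantized back to the query dtype via restore_fp8_kv_cache before the plain matmul attention. A rough stand-in using int8 plus a per-tensor scale, only to illustrate the round trip (the real FP8 layout and helper signatures differ):

import torch


def toy_quantize(t: torch.Tensor):
    scale = t.abs().amax().clamp(min=1e-6) / 127.0
    return (t / scale).round().clamp(-127, 127).to(torch.int8), scale


def toy_restore(q: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    return (q.to(torch.float32) * scale).to(dtype)


k = torch.randn(1, 8, 16, 64, dtype=torch.float16)
qk, s = toy_quantize(k)
k_restored = toy_restore(qk, s, torch.float16)
print((k - k_restored).abs().max())  # small quantization error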