diff --git a/python/llm/src/ipex_llm/transformers/kv.py b/python/llm/src/ipex_llm/transformers/kv.py
index 97b345cd..b0e52282 100644
--- a/python/llm/src/ipex_llm/transformers/kv.py
+++ b/python/llm/src/ipex_llm/transformers/kv.py
@@ -356,6 +356,22 @@ class DynamicCompressCache(DynamicCache):
             return 0
         return self.real_kv_len
 
+    @classmethod
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        num_hidden_layers: int = None
+    ) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
+        backward compatibility."""
+        cache = cls(num_hidden_layers)
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                invalidInputError(
+                    len(key_states) == 0 and len(value_states) == 0,
+                    "from_legacy_cache should be called with an empty kv cache.")
+        return cache
+
 
 class DynamicCompressFp8Cache(DynamicCompressCache, DynamicFp8Cache):
     def update(
diff --git a/python/llm/src/ipex_llm/transformers/models/llama32.py b/python/llm/src/ipex_llm/transformers/models/llama32.py
index cad33744..e6a9c53b 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama32.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama32.py
@@ -50,7 +50,10 @@ from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
+from ipex_llm.transformers.models.utils import should_use_compresskv, \
+    is_enough_kv_cache_room_4_36
+from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache, \
+    DynamicCompressFp8Cache
 
 
 def llama_model_forward(
@@ -83,11 +86,25 @@ def llama_model_forward(
         self.layers[0].mlp.down_proj, inputs,
         self.config.num_attention_heads // self.config.num_key_value_heads
     )
+    use_compresskv = should_use_compresskv(inputs, inputs.shape[1]) or \
+        isinstance(past_key_values, DynamicCompressCache)
+    # disable llama3.2 1b for prefill performance and output quality
+    use_compresskv = use_compresskv and self.config.hidden_size != 2048
     if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+        if use_compresskv and not isinstance(past_key_values, DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        elif use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+        elif (
+            not use_quantize_kv
+            and not use_compresskv
+            and not isinstance(past_key_values, DynamicNormalCache)
+        ):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
+    # IPEX-LLM OPT end
 
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -182,6 +199,9 @@ def llama_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     qkv = self.qkv_proj(hidden_states)
     qkv = qkv.view(bsz, q_len,
                    self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
     qkv = qkv.transpose(1, 2)
@@ -201,8 +221,17 @@ def llama_attention_forward(
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                          self.layer_idx, None)
+        # [CompressKV]
+        if use_compresskv:
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                          q_len)
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, 256)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                              self.layer_idx, None)
 
     kv_seq_len = key_states.size(2)
     if attention_mask is not None:  # no matter the length, we just slice it
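
Note (not part of the patch): the if/elif chain added to llama_model_forward establishes a fixed precedence when picking the KV-cache implementation: CompressKV first (optionally combined with FP8 quantization), then the plain FP8 quantized cache, then the normal cache. Below is a minimal standalone sketch of that ordering; the helper name select_cache_class and the stub classes are illustrative only, the real classes live in ipex_llm.transformers.kv.

# Sketch only: stand-ins for the real cache classes in ipex_llm.transformers.kv.
class DynamicNormalCache: ...
class DynamicFp8Cache: ...
class DynamicCompressCache: ...
class DynamicCompressFp8Cache(DynamicCompressCache, DynamicFp8Cache): ...


def select_cache_class(use_compresskv: bool, use_quantize_kv: bool) -> type:
    # Mirrors the patched if/elif chain: CompressKV takes precedence,
    # then the quantized FP8 cache, then the normal cache.
    if use_compresskv:
        return DynamicCompressFp8Cache if use_quantize_kv else DynamicCompressCache
    if use_quantize_kv:
        return DynamicFp8Cache
    return DynamicNormalCache


assert select_cache_class(True, True) is DynamicCompressFp8Cache
assert select_cache_class(True, False) is DynamicCompressCache
assert select_cache_class(False, True) is DynamicFp8Cache
assert select_cache_class(False, False) is DynamicNormalCache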