diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py index 5e633da7..2c9c17e7 100644 --- a/python/llm/src/ipex_llm/transformers/models/llama.py +++ b/python/llm/src/ipex_llm/transformers/models/llama.py @@ -128,7 +128,9 @@ def llama_model_forward_4_36( use_quantize = use_quantize_kv_cache( self.layers[0].mlp.up_proj, input, self.config.num_attention_heads//self.config.num_key_value_heads) - if should_use_compresskv(input, input.shape[1]): + use_compresskv = should_use_compresskv(input, input.shape[1]) or \ + isinstance(past_key_values, DynamicCompressCache) + if use_compresskv: if not isinstance(past_key_values, DynamicCompressCache): if use_quantize: past_key_values = DynamicCompressFp8Cache.from_legacy_cache( @@ -137,7 +139,7 @@ def llama_model_forward_4_36( past_key_values = DynamicCompressCache.from_legacy_cache( past_key_values) elif use_quantize: - if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)): + if not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) return llama_model_forward_4_36_internal( self=self, @@ -174,7 +176,9 @@ def llama_model_forward_4_38( use_quantize = use_quantize_kv_cache( self.layers[0].mlp.up_proj, input, self.config.num_attention_heads//self.config.num_key_value_heads) - if should_use_compresskv(input, input.shape[1]): + use_compresskv = should_use_compresskv(input, input.shape[1]) or \ + isinstance(past_key_values, DynamicCompressCache) + if use_compresskv: if not isinstance(past_key_values, DynamicCompressCache): if use_quantize: past_key_values = DynamicCompressFp8Cache.from_legacy_cache( @@ -183,7 +187,7 @@ def llama_model_forward_4_38( past_key_values = DynamicCompressCache.from_legacy_cache( past_key_values) elif use_quantize: - if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)): + if not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) return llama_model_forward_4_38_internal( self=self, @@ -221,7 +225,9 @@ def llama_model_forward_4_41( use_quantize = use_quantize_kv_cache( self.layers[0].mlp.up_proj, input, self.config.num_attention_heads//self.config.num_key_value_heads) - if should_use_compresskv(input, input.shape[1]): + use_compresskv = should_use_compresskv(input, input.shape[1]) or \ + isinstance(past_key_values, DynamicCompressCache) + if use_compresskv: if not isinstance(past_key_values, DynamicCompressCache): if use_quantize: past_key_values = DynamicCompressFp8Cache.from_legacy_cache( @@ -230,7 +236,7 @@ def llama_model_forward_4_41( past_key_values = DynamicCompressCache.from_legacy_cache( past_key_values) elif use_quantize: - if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)): + if not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) return llama_model_forward_4_41_internal( self=self, diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py index afbcde6c..d248c507 100644 --- a/python/llm/src/ipex_llm/transformers/models/minicpm.py +++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py @@ -182,7 +182,8 @@ def minicpm_model_forward_wrapper(origin_forward): use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs, self.config.num_attention_heads // self.config.num_key_value_heads) - use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) + use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \ + isinstance(past_key_values, DynamicCompressCache) use_cache = use_cache if use_cache is not None else self.config.use_cache if use_cache: @@ -192,11 +193,11 @@ def minicpm_model_forward_wrapper(origin_forward): past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values) else: past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values) - elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache, - DynamicCompressCache)): + elif (use_quantize_kv and not use_compress_kv + and not isinstance(past_key_values, DynamicFp8Cache)): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) elif (not use_quantize_kv and not use_compress_kv - and not isinstance(past_key_values, (DynamicNormalCache, DynamicCompressCache))): + and not isinstance(past_key_values, DynamicNormalCache)): past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) # ipex-llm changes end return origin_forward( diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py index 823fb103..bfa380c2 100644 --- a/python/llm/src/ipex_llm/transformers/models/phi3.py +++ b/python/llm/src/ipex_llm/transformers/models/phi3.py @@ -256,7 +256,8 @@ def phi3_model_forward_wrapper(origin_model_forward): use_cache = use_cache if use_cache is not None else self.config.use_cache inputs = input_ids if input_ids is not None else inputs_embeds use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs) - use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) + use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \ + isinstance(past_key_values, DynamicCompressCache) if use_cache: if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache): @@ -264,13 +265,11 @@ def phi3_model_forward_wrapper(origin_model_forward): past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values) else: past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values) - if use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache, - DynamicCompressCache)): + if use_quantize_kv and not use_compress_kv and not isinstance(past_key_values, + DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values, - (DynamicNormalCache, - DynamicCompressCache - )): + DynamicNormalCache): past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) if past_key_values.get_seq_length() == 0: n_layer = self.config.num_hidden_layers diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py index c01488a6..802c5e7e 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py @@ -120,7 +120,8 @@ def qwen2_model_forward( and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs, self.config.num_attention_heads//self.config.num_key_value_heads) ) - use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) + use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \ + isinstance(past_key_values, DynamicCompressCache) if use_cache: if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache): @@ -128,12 +129,11 @@ def qwen2_model_forward( past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values) else: past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values) - elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache, - DynamicCompressCache)): + elif use_quantize_kv and not use_compress_kv and not isinstance(past_key_values, + DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values, - (DynamicNormalCache, - DynamicCompressCache)): + DynamicNormalCache): past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) # ipex-llm changes end @@ -316,7 +316,8 @@ def qwen2_model_forward_4_42( and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds, self.config.num_attention_heads//self.config.num_key_value_heads) ) - use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1]) + use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1]) or \ + isinstance(past_key_values, DynamicCompressCache) if use_cache: if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache): @@ -324,12 +325,11 @@ def qwen2_model_forward_4_42( past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values) else: past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values) - elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache, - DynamicCompressCache)): + elif use_quantize_kv and not use_compress_kv and not isinstance(past_key_values, + DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) - elif not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values, - (DynamicNormalCache, - DynamicCompressCache)): + if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values, + DynamicNormalCache): past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) # ipex-llm changes end