From d2abc9711b34ae6f9abe31a0b86d75494ac2bbc3 Mon Sep 17 00:00:00 2001 From: Yina Chen <33650826+cyita@users.noreply.github.com> Date: Wed, 7 Aug 2024 11:21:57 +0300 Subject: [PATCH] Fix MTL 4k input qwen2 compresskv error (#11734) * fix * fix style --- python/llm/src/ipex_llm/transformers/kv.py | 1 - python/llm/src/ipex_llm/transformers/models/qwen2.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/kv.py b/python/llm/src/ipex_llm/transformers/kv.py index e2f386ea..1543ab34 100644 --- a/python/llm/src/ipex_llm/transformers/kv.py +++ b/python/llm/src/ipex_llm/transformers/kv.py @@ -152,7 +152,6 @@ def compress_kv(attn_config, key_states, query_states, value_states, attention_m if not hasattr(attn_config, 'pooling'): attn_config.pooling = 'maxpool' bsz, num_heads, q_len, head_dim = query_states.shape - print(f"attn_config.max_capacity_prompt: ", attn_config.max_capacity_prompt, " ", q_len) if q_len <= attn_config.max_capacity_prompt: return key_states, value_states else: diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py index 5a56add8..91335574 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py @@ -127,7 +127,8 @@ def qwen2_model_forward( DynamicCompressCache): past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values) if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values, - DynamicNormalCache): + (DynamicNormalCache, + DynamicCompressCache)): past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) # ipex-llm changes end