Fix MTL 4k input qwen2 compresskv error (#11734)

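* fix: remove the debug print from compress_kv, and in qwen2_model_forward leave an existing DynamicCompressCache as-is instead of converting it with DynamicNormalCache.from_legacy_cache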

* fix style
This commit is contained in:
Yina Chen 2024-08-07 11:21:57 +03:00 committed by GitHub
parent a71ae7c22b
commit d2abc9711b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View file

@ -152,7 +152,6 @@ def compress_kv(attn_config, key_states, query_states, value_states, attention_m
if not hasattr(attn_config, 'pooling'): if not hasattr(attn_config, 'pooling'):
attn_config.pooling = 'maxpool' attn_config.pooling = 'maxpool'
bsz, num_heads, q_len, head_dim = query_states.shape bsz, num_heads, q_len, head_dim = query_states.shape
print(f"attn_config.max_capacity_prompt: ", attn_config.max_capacity_prompt, " ", q_len)
if q_len <= attn_config.max_capacity_prompt: if q_len <= attn_config.max_capacity_prompt:
return key_states, value_states return key_states, value_states
else: else:

View file

@ -127,7 +127,8 @@ def qwen2_model_forward(
DynamicCompressCache): DynamicCompressCache):
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values) past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values, if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
DynamicNormalCache): (DynamicNormalCache,
DynamicCompressCache)):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length) past_key_values_length = past_key_values.get_usable_length(seq_length)
# ipex-llm changes end # ipex-llm changes end