Update compresskv model forward type logic (#11868)

* update

* fix
This commit is contained in:
Yina Chen 2024-08-20 13:11:37 +03:00 committed by GitHub
parent 3ee194d983
commit c3c058373f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 33 additions and 27 deletions

View file

@ -128,7 +128,9 @@ def llama_model_forward_4_36(
use_quantize = use_quantize_kv_cache(
self.layers[0].mlp.up_proj, input,
self.config.num_attention_heads//self.config.num_key_value_heads)
if should_use_compresskv(input, input.shape[1]):
use_compresskv = should_use_compresskv(input, input.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
if use_compresskv:
if not isinstance(past_key_values, DynamicCompressCache):
if use_quantize:
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
@ -137,7 +139,7 @@ def llama_model_forward_4_36(
past_key_values = DynamicCompressCache.from_legacy_cache(
past_key_values)
elif use_quantize:
if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
return llama_model_forward_4_36_internal(
self=self,
@ -174,7 +176,9 @@ def llama_model_forward_4_38(
use_quantize = use_quantize_kv_cache(
self.layers[0].mlp.up_proj, input,
self.config.num_attention_heads//self.config.num_key_value_heads)
if should_use_compresskv(input, input.shape[1]):
use_compresskv = should_use_compresskv(input, input.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
if use_compresskv:
if not isinstance(past_key_values, DynamicCompressCache):
if use_quantize:
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
@ -183,7 +187,7 @@ def llama_model_forward_4_38(
past_key_values = DynamicCompressCache.from_legacy_cache(
past_key_values)
elif use_quantize:
if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
return llama_model_forward_4_38_internal(
self=self,
@ -221,7 +225,9 @@ def llama_model_forward_4_41(
use_quantize = use_quantize_kv_cache(
self.layers[0].mlp.up_proj, input,
self.config.num_attention_heads//self.config.num_key_value_heads)
if should_use_compresskv(input, input.shape[1]):
use_compresskv = should_use_compresskv(input, input.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
if use_compresskv:
if not isinstance(past_key_values, DynamicCompressCache):
if use_quantize:
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
@ -230,7 +236,7 @@ def llama_model_forward_4_41(
past_key_values = DynamicCompressCache.from_legacy_cache(
past_key_values)
elif use_quantize:
if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
return llama_model_forward_4_41_internal(
self=self,

View file

@ -182,7 +182,8 @@ def minicpm_model_forward_wrapper(origin_forward):
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
self.config.num_attention_heads //
self.config.num_key_value_heads)
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if use_cache:
@ -192,11 +193,11 @@ def minicpm_model_forward_wrapper(origin_forward):
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
DynamicCompressCache)):
elif (use_quantize_kv and not use_compress_kv
and not isinstance(past_key_values, DynamicFp8Cache)):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif (not use_quantize_kv and not use_compress_kv
and not isinstance(past_key_values, (DynamicNormalCache, DynamicCompressCache))):
and not isinstance(past_key_values, DynamicNormalCache)):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
# ipex-llm changes end
return origin_forward(

View file

@ -256,7 +256,8 @@ def phi3_model_forward_wrapper(origin_model_forward):
use_cache = use_cache if use_cache is not None else self.config.use_cache
inputs = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
if use_cache:
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
@ -264,13 +265,11 @@ def phi3_model_forward_wrapper(origin_model_forward):
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
if use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
DynamicCompressCache)):
if use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
(DynamicNormalCache,
DynamicCompressCache
)):
DynamicNormalCache):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
if past_key_values.get_seq_length() == 0:
n_layer = self.config.num_hidden_layers

View file

@ -120,7 +120,8 @@ def qwen2_model_forward(
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
self.config.num_attention_heads//self.config.num_key_value_heads)
)
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
if use_cache:
if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
@ -128,12 +129,11 @@ def qwen2_model_forward(
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
DynamicCompressCache)):
elif use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
(DynamicNormalCache,
DynamicCompressCache)):
DynamicNormalCache):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
# ipex-llm changes end
@ -316,7 +316,8 @@ def qwen2_model_forward_4_42(
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds,
self.config.num_attention_heads//self.config.num_key_value_heads)
)
use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1])
use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
if use_cache:
if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
@ -324,12 +325,11 @@ def qwen2_model_forward_4_42(
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
DynamicCompressCache)):
elif use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
(DynamicNormalCache,
DynamicCompressCache)):
if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
DynamicNormalCache):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
# ipex-llm changes end