MiniCPM-V support compresskv (#11779)

* fix check error

* fix other models

* remove print
Yina Chen 2024-08-13 14:03:40 +03:00 committed by GitHub
parent 3998de14f0
commit 7cd6ec9723
7 changed files with 9 additions and 9 deletions
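
The fix is the same one-token change in every file: the sequence length passed to should_use_compresskv is now read from shape[1] instead of shape[-1]. For 2-D input_ids the two indices agree, but MiniCPM-V drives these forwards with 3-D inputs_embeds, where shape[-1] is the hidden size rather than the sequence length. A minimal sketch of the difference (tensor sizes chosen for illustration only):

    import torch

    # input_ids: (batch, seq_len) -- shape[-1] and shape[1] agree
    input_ids = torch.zeros(1, 512, dtype=torch.long)
    assert input_ids.shape[-1] == input_ids.shape[1] == 512

    # inputs_embeds: (batch, seq_len, hidden) -- shape[-1] is the hidden
    # size, so using it as the sequence length breaks the compress-kv check
    inputs_embeds = torch.zeros(1, 512, 4096)
    assert inputs_embeds.shape[-1] == 4096  # hidden size, wrong for seq_len
    assert inputs_embeds.shape[1] == 512    # the actual sequence length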

@@ -87,7 +87,7 @@ def chatglm2_model_forward(
                                     dtype=inputs_embeds.dtype, device=inputs_embeds.device)
     if use_cache:
-        use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[-1])
+        use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 input_ids)
         if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,

@@ -51,7 +51,7 @@ def chatglm4_model_forward(
     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
-        use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
+        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 inputs)
         if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,

@@ -128,7 +128,7 @@ def llama_model_forward_4_36(
                                  self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[-1]):
+        elif should_use_compresskv(input, input.shape[1]):
             # if use quantize kv, compress kv will be ignored now
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(
@@ -168,7 +168,7 @@ def llama_model_forward_4_38(
                                  self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[-1]):
+        elif should_use_compresskv(input, input.shape[1]):
             # if use quantize kv, compress kv will be ignored now
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(
@@ -209,7 +209,7 @@ def llama_model_forward_4_41(
                                  self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[-1]):
+        elif should_use_compresskv(input, input.shape[1]):
             # if use quantize kv, compress kv will be ignored now
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(

@@ -628,7 +628,7 @@ def minicpm_model_forward(
                                  self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[-1]):
+        elif should_use_compresskv(input, input.shape[1]):
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)

@@ -211,7 +211,7 @@ def mistral_model_forward_4_36(
                                  self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input_ids, input_ids.shape[-1]):
+        elif should_use_compresskv(input_ids, input_ids.shape[1]):
             # if use quantize kv, compress kv will be ignored now
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(

@@ -258,7 +258,7 @@ def phi3_model_forward_wrapper(origin_model_forward):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         input = input_ids if input_ids is not None else inputs_embeds
         use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
-        use_compress_kv = should_use_compresskv(input, input.shape[-1])
+        use_compress_kv = should_use_compresskv(input, input.shape[1])
         if use_cache:
             if use_compress_kv and not isinstance(past_key_values,
                                                   DynamicCompressCache):

@@ -118,7 +118,7 @@ def qwen2_model_forward(
         and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
                                   self.config.num_attention_heads//self.config.num_key_value_heads)
     )
-    use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
+    use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
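
All seven forwards share the same cache-selection pattern around this check. A condensed sketch of that shared logic, with names taken from the hunks above and the control flow simplified; this is not a verbatim copy of any one file:

    # DynamicFp8Cache, DynamicCompressCache and should_use_compresskv are the
    # helpers referenced in the hunks above, assumed importable in this module.
    def select_kv_cache(past_key_values, inputs, use_quantize_kv, use_cache):
        if not use_cache:
            return past_key_values
        if use_quantize_kv:
            # quantized kv takes priority; compress kv is ignored for now
            if not isinstance(past_key_values, DynamicFp8Cache):
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
        # seq_len must come from dim 1: inputs may be 3-D inputs_embeds
        elif should_use_compresskv(inputs, inputs.shape[1]):
            if not isinstance(past_key_values, DynamicCompressCache):
                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
        return past_key_values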