MiniCPM-V support compresskv (#11779)
* fix check error * fix other models * remove print
This commit is contained in:
parent
3998de14f0
commit
7cd6ec9723
7 changed files with 9 additions and 9 deletions
|
|
@ -87,7 +87,7 @@ def chatglm2_model_forward(
|
||||||
dtype=inputs_embeds.dtype, device=inputs_embeds.device)
|
dtype=inputs_embeds.dtype, device=inputs_embeds.device)
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[-1])
|
use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
|
||||||
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
|
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
|
||||||
input_ids)
|
input_ids)
|
||||||
if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
|
if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ def chatglm4_model_forward(
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
inputs = input_ids if input_ids is not None else inputs_embeds
|
inputs = input_ids if input_ids is not None else inputs_embeds
|
||||||
use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
|
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
|
||||||
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
|
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
|
||||||
inputs)
|
inputs)
|
||||||
if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
|
if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
|
||||||
|
|
|
||||||
|
|
@ -128,7 +128,7 @@ def llama_model_forward_4_36(
|
||||||
self.config.num_attention_heads//self.config.num_key_value_heads):
|
self.config.num_attention_heads//self.config.num_key_value_heads):
|
||||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||||
elif should_use_compresskv(input, input.shape[-1]):
|
elif should_use_compresskv(input, input.shape[1]):
|
||||||
# if use quantize kv, compress kv will be ignored now
|
# if use quantize kv, compress kv will be ignored now
|
||||||
if not isinstance(past_key_values, DynamicCompressCache):
|
if not isinstance(past_key_values, DynamicCompressCache):
|
||||||
past_key_values = DynamicCompressCache.from_legacy_cache(
|
past_key_values = DynamicCompressCache.from_legacy_cache(
|
||||||
|
|
@ -168,7 +168,7 @@ def llama_model_forward_4_38(
|
||||||
self.config.num_attention_heads//self.config.num_key_value_heads):
|
self.config.num_attention_heads//self.config.num_key_value_heads):
|
||||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||||
elif should_use_compresskv(input, input.shape[-1]):
|
elif should_use_compresskv(input, input.shape[1]):
|
||||||
# if use quantize kv, compress kv will be ignored now
|
# if use quantize kv, compress kv will be ignored now
|
||||||
if not isinstance(past_key_values, DynamicCompressCache):
|
if not isinstance(past_key_values, DynamicCompressCache):
|
||||||
past_key_values = DynamicCompressCache.from_legacy_cache(
|
past_key_values = DynamicCompressCache.from_legacy_cache(
|
||||||
|
|
@ -209,7 +209,7 @@ def llama_model_forward_4_41(
|
||||||
self.config.num_attention_heads//self.config.num_key_value_heads):
|
self.config.num_attention_heads//self.config.num_key_value_heads):
|
||||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||||
elif should_use_compresskv(input, input.shape[-1]):
|
elif should_use_compresskv(input, input.shape[1]):
|
||||||
# if use quantize kv, compress kv will be ignored now
|
# if use quantize kv, compress kv will be ignored now
|
||||||
if not isinstance(past_key_values, DynamicCompressCache):
|
if not isinstance(past_key_values, DynamicCompressCache):
|
||||||
past_key_values = DynamicCompressCache.from_legacy_cache(
|
past_key_values = DynamicCompressCache.from_legacy_cache(
|
||||||
|
|
|
||||||
|
|
@ -628,7 +628,7 @@ def minicpm_model_forward(
|
||||||
self.config.num_key_value_heads):
|
self.config.num_key_value_heads):
|
||||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||||
elif should_use_compresskv(input, input.shape[-1]):
|
elif should_use_compresskv(input, input.shape[1]):
|
||||||
if not isinstance(past_key_values, DynamicCompressCache):
|
if not isinstance(past_key_values, DynamicCompressCache):
|
||||||
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -211,7 +211,7 @@ def mistral_model_forward_4_36(
|
||||||
self.config.num_attention_heads//self.config.num_key_value_heads):
|
self.config.num_attention_heads//self.config.num_key_value_heads):
|
||||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||||
elif should_use_compresskv(input_ids, input_ids.shape[-1]):
|
elif should_use_compresskv(input_ids, input_ids.shape[1]):
|
||||||
# if use quantize kv, compress kv will be ignored now
|
# if use quantize kv, compress kv will be ignored now
|
||||||
if not isinstance(past_key_values, DynamicCompressCache):
|
if not isinstance(past_key_values, DynamicCompressCache):
|
||||||
past_key_values = DynamicCompressCache.from_legacy_cache(
|
past_key_values = DynamicCompressCache.from_legacy_cache(
|
||||||
|
|
|
||||||
|
|
@ -258,7 +258,7 @@ def phi3_model_forward_wrapper(origin_model_forward):
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
input = input_ids if input_ids is not None else inputs_embeds
|
input = input_ids if input_ids is not None else inputs_embeds
|
||||||
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
|
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
|
||||||
use_compress_kv = should_use_compresskv(input, input.shape[-1])
|
use_compress_kv = should_use_compresskv(input, input.shape[1])
|
||||||
if use_cache:
|
if use_cache:
|
||||||
if use_compress_kv and not isinstance(past_key_values,
|
if use_compress_kv and not isinstance(past_key_values,
|
||||||
DynamicCompressCache):
|
DynamicCompressCache):
|
||||||
|
|
|
||||||
|
|
@ -118,7 +118,7 @@ def qwen2_model_forward(
|
||||||
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
|
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
|
||||||
self.config.num_attention_heads//self.config.num_key_value_heads)
|
self.config.num_attention_heads//self.config.num_key_value_heads)
|
||||||
)
|
)
|
||||||
use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
|
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
|
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue