parent 3ee194d983
commit c3c058373f
4 changed files with 33 additions and 27 deletions
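All four files apply the same adjustment to the KV-cache selection logic: the compress-KV flag now also stays on when the incoming past_key_values is already a DynamicCompressCache (the length-based should_use_compresskv heuristic alone would stop firing on single-token decode steps), and the remaining branches only check against the cache class they themselves produce. A minimal sketch of the resulting rule follows, assuming the helpers and cache classes behave as they appear in the hunks below; pick_kv_cache and the import path are illustrative, not code from this commit.

# Minimal sketch of the selection rule the hunks below converge on (illustrative only).
# Assumed import path: these classes are provided by ipex-llm's KV-cache module.
from ipex_llm.transformers.kv import (
    DynamicCompressCache, DynamicCompressFp8Cache,
    DynamicFp8Cache, DynamicNormalCache,
)


def pick_kv_cache(past_key_values, use_quantize_kv, use_compress_kv):
    # CompressKV takes priority. The new `or isinstance(..., DynamicCompressCache)`
    # check keeps use_compress_kv True on decode steps, where should_use_compresskv()
    # alone would no longer fire for the 1-token input.
    if use_compress_kv:
        if not isinstance(past_key_values, DynamicCompressCache):
            cache_cls = DynamicCompressFp8Cache if use_quantize_kv else DynamicCompressCache
            past_key_values = cache_cls.from_legacy_cache(past_key_values)
    # Quantized (FP8) KV cache; the old isinstance checks against
    # DynamicCompressCache are dropped because use_compress_kv now covers that case.
    elif use_quantize_kv:
        if not isinstance(past_key_values, DynamicFp8Cache):
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
    # Models that also normalise the plain cache (minicpm/phi3/qwen2) fall back
    # to DynamicNormalCache.
    elif not isinstance(past_key_values, DynamicNormalCache):
        past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
    return past_key_values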
@@ -128,7 +128,9 @@ def llama_model_forward_4_36(
     use_quantize = use_quantize_kv_cache(
         self.layers[0].mlp.up_proj, input,
         self.config.num_attention_heads//self.config.num_key_value_heads)
-    if should_use_compresskv(input, input.shape[1]):
+    use_compresskv = should_use_compresskv(input, input.shape[1]) or \
+        isinstance(past_key_values, DynamicCompressCache)
+    if use_compresskv:
         if not isinstance(past_key_values, DynamicCompressCache):
             if use_quantize:
                 past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
@@ -137,7 +139,7 @@ def llama_model_forward_4_36(
                 past_key_values = DynamicCompressCache.from_legacy_cache(
                     past_key_values)
     elif use_quantize:
-        if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
+        if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return llama_model_forward_4_36_internal(
         self=self,
@@ -174,7 +176,9 @@ def llama_model_forward_4_38(
     use_quantize = use_quantize_kv_cache(
         self.layers[0].mlp.up_proj, input,
         self.config.num_attention_heads//self.config.num_key_value_heads)
-    if should_use_compresskv(input, input.shape[1]):
+    use_compresskv = should_use_compresskv(input, input.shape[1]) or \
+        isinstance(past_key_values, DynamicCompressCache)
+    if use_compresskv:
         if not isinstance(past_key_values, DynamicCompressCache):
             if use_quantize:
                 past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
@@ -183,7 +187,7 @@ def llama_model_forward_4_38(
                 past_key_values = DynamicCompressCache.from_legacy_cache(
                     past_key_values)
     elif use_quantize:
-        if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
+        if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return llama_model_forward_4_38_internal(
         self=self,
@@ -221,7 +225,9 @@ def llama_model_forward_4_41(
     use_quantize = use_quantize_kv_cache(
         self.layers[0].mlp.up_proj, input,
         self.config.num_attention_heads//self.config.num_key_value_heads)
-    if should_use_compresskv(input, input.shape[1]):
+    use_compresskv = should_use_compresskv(input, input.shape[1]) or \
+        isinstance(past_key_values, DynamicCompressCache)
+    if use_compresskv:
         if not isinstance(past_key_values, DynamicCompressCache):
             if use_quantize:
                 past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
@@ -230,7 +236,7 @@ def llama_model_forward_4_41(
                 past_key_values = DynamicCompressCache.from_legacy_cache(
                     past_key_values)
     elif use_quantize:
-        if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
+        if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return llama_model_forward_4_41_internal(
         self=self,

@@ -182,7 +182,8 @@ def minicpm_model_forward_wrapper(origin_forward):
         use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
                                                 self.config.num_attention_heads //
                                                 self.config.num_key_value_heads)
-        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
+        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
+            isinstance(past_key_values, DynamicCompressCache)
 
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         if use_cache:
@@ -192,11 +193,11 @@ def minicpm_model_forward_wrapper(origin_forward):
                     past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
                 else:
                     past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
-            elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
-                                                                      DynamicCompressCache)):
+            elif (use_quantize_kv and not use_compress_kv
+                    and not isinstance(past_key_values, DynamicFp8Cache)):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
             elif (not use_quantize_kv and not use_compress_kv
-                  and not isinstance(past_key_values, (DynamicNormalCache, DynamicCompressCache))):
+                  and not isinstance(past_key_values, DynamicNormalCache)):
                 past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
         # ipex-llm changes end
         return origin_forward(

@@ -256,7 +256,8 @@ def phi3_model_forward_wrapper(origin_model_forward):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
-        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
+        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
+            isinstance(past_key_values, DynamicCompressCache)
         if use_cache:
             if use_compress_kv and not isinstance(past_key_values,
                                                   DynamicCompressCache):
@@ -264,13 +265,11 @@ def phi3_model_forward_wrapper(origin_model_forward):
                     past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
                 else:
                     past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
-            if use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
-                                                                    DynamicCompressCache)):
+            if use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                          DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
             if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
-                                                                              (DynamicNormalCache,
-                                                                               DynamicCompressCache
-                                                                               )):
+                                                                              DynamicNormalCache):
                 past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
             if past_key_values.get_seq_length() == 0:
                 n_layer = self.config.num_hidden_layers

@@ -120,7 +120,8 @@ def qwen2_model_forward(
                        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
                                                  self.config.num_attention_heads//self.config.num_key_value_heads)
                        )
-    use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
+    use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
+        isinstance(past_key_values, DynamicCompressCache)
 
     if use_cache:
         if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
@@ -128,12 +129,11 @@ def qwen2_model_forward(
                 past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
             else:
                 past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
-        elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
-                                                                  DynamicCompressCache)):
+        elif use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                        DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
-                                                                          (DynamicNormalCache,
-                                                                           DynamicCompressCache)):
+                                                                          DynamicNormalCache):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
         past_key_values_length = past_key_values.get_usable_length(seq_length)
         # ipex-llm changes end
@@ -316,7 +316,8 @@ def qwen2_model_forward_4_42(
                        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds,
                                                  self.config.num_attention_heads//self.config.num_key_value_heads)
                        )
-    use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1])
+    use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1]) or \
+        isinstance(past_key_values, DynamicCompressCache)
 
     if use_cache:
         if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
@@ -324,12 +325,11 @@ def qwen2_model_forward_4_42(
                 past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
             else:
                 past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
-        elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
-                                                                  DynamicCompressCache)):
+        elif use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                        DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
-                                                                            (DynamicNormalCache,
-                                                                             DynamicCompressCache)):
+        if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                          DynamicNormalCache):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
         # ipex-llm changes end
 