commit c3c058373f (parent 3ee194d983)
4 changed files with 33 additions and 27 deletions
@@ -128,7 +128,9 @@ def llama_model_forward_4_36(
         use_quantize = use_quantize_kv_cache(
             self.layers[0].mlp.up_proj, input,
             self.config.num_attention_heads//self.config.num_key_value_heads)
-        if should_use_compresskv(input, input.shape[1]):
+        use_compresskv = should_use_compresskv(input, input.shape[1]) or \
+            isinstance(past_key_values, DynamicCompressCache)
+        if use_compresskv:
             if not isinstance(past_key_values, DynamicCompressCache):
                 if use_quantize:
                     past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
@@ -137,7 +139,7 @@ def llama_model_forward_4_36(
                     past_key_values = DynamicCompressCache.from_legacy_cache(
                         past_key_values)
         elif use_quantize:
-            if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
+            if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return llama_model_forward_4_36_internal(
         self=self,
@@ -174,7 +176,9 @@ def llama_model_forward_4_38(
         use_quantize = use_quantize_kv_cache(
             self.layers[0].mlp.up_proj, input,
             self.config.num_attention_heads//self.config.num_key_value_heads)
-        if should_use_compresskv(input, input.shape[1]):
+        use_compresskv = should_use_compresskv(input, input.shape[1]) or \
+            isinstance(past_key_values, DynamicCompressCache)
+        if use_compresskv:
             if not isinstance(past_key_values, DynamicCompressCache):
                 if use_quantize:
                     past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
@@ -183,7 +187,7 @@ def llama_model_forward_4_38(
                     past_key_values = DynamicCompressCache.from_legacy_cache(
                         past_key_values)
         elif use_quantize:
-            if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
+            if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return llama_model_forward_4_38_internal(
         self=self,
@@ -221,7 +225,9 @@ def llama_model_forward_4_41(
         use_quantize = use_quantize_kv_cache(
             self.layers[0].mlp.up_proj, input,
             self.config.num_attention_heads//self.config.num_key_value_heads)
-        if should_use_compresskv(input, input.shape[1]):
+        use_compresskv = should_use_compresskv(input, input.shape[1]) or \
+            isinstance(past_key_values, DynamicCompressCache)
+        if use_compresskv:
             if not isinstance(past_key_values, DynamicCompressCache):
                 if use_quantize:
                     past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
@@ -230,7 +236,7 @@ def llama_model_forward_4_41(
                     past_key_values = DynamicCompressCache.from_legacy_cache(
                         past_key_values)
         elif use_quantize:
-            if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
+            if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return llama_model_forward_4_41_internal(
         self=self,
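All three llama forwards above receive the same two edits: use_compresskv is now also true when the incoming past_key_values is already a DynamicCompressCache, and the quantize branch then only needs to test for DynamicFp8Cache. A minimal, self-contained sketch of the resulting selection logic follows; the stub classes and the pick_kv_cache helper are illustrative stand-ins, not part of the ipex-llm API.

# Stand-ins for the real ipex_llm.transformers.kv cache types.
class DynamicFp8Cache(list): ...
class DynamicCompressCache(list): ...
class DynamicCompressFp8Cache(DynamicCompressCache): ...


def pick_kv_cache(past_key_values, use_quantize, heuristic_says_compress):
    """Hypothetical helper mirroring the new llama_model_forward_4_3x logic."""
    # Compression is now "sticky": once the cache is a DynamicCompressCache,
    # later steps keep the compress path even if should_use_compresskv()
    # would return False for the current (e.g. single-token) input.
    use_compresskv = heuristic_says_compress or \
        isinstance(past_key_values, DynamicCompressCache)
    if use_compresskv:
        if not isinstance(past_key_values, DynamicCompressCache):
            cache_cls = DynamicCompressFp8Cache if use_quantize else DynamicCompressCache
            past_key_values = cache_cls(past_key_values or [])
    elif use_quantize:
        # DynamicCompressCache no longer needs to appear in this check: the
        # branch above already short-circuits whenever the cache is compressed.
        if not isinstance(past_key_values, DynamicFp8Cache):
            past_key_values = DynamicFp8Cache(past_key_values or [])
    return past_key_values


# Prefill chooses a compressed cache; the next (1-token) step keeps it.
cache = pick_kv_cache(None, use_quantize=True, heuristic_says_compress=True)
assert isinstance(pick_kv_cache(cache, True, False), DynamicCompressFp8Cache)

The final assert is the behavioral point of the change: a cache compressed at prefill stays compressed on later steps even when the per-step heuristic no longer fires.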
@@ -182,7 +182,8 @@ def minicpm_model_forward_wrapper(origin_forward):
         use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
                                                 self.config.num_attention_heads //
                                                 self.config.num_key_value_heads)
-        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
+        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
+            isinstance(past_key_values, DynamicCompressCache)

         use_cache = use_cache if use_cache is not None else self.config.use_cache
         if use_cache:
@@ -192,11 +193,11 @@ def minicpm_model_forward_wrapper(origin_forward):
                     past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
                 else:
                     past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
-            elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
-                                                                      DynamicCompressCache)):
+            elif (use_quantize_kv and not use_compress_kv
+                  and not isinstance(past_key_values, DynamicFp8Cache)):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
             elif (not use_quantize_kv and not use_compress_kv
-                  and not isinstance(past_key_values, (DynamicNormalCache, DynamicCompressCache))):
+                  and not isinstance(past_key_values, DynamicNormalCache)):
                 past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
         # ipex-llm changes end
         return origin_forward(
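The minicpm hunk applies the same pattern inside a single if/elif/elif chain, so at most one conversion runs per call and the precedence is compress, then fp8, then normal. A compact illustrative model of that chain, using plain strings instead of the real cache classes (select_cache_kind is hypothetical, not ipex-llm code):

def select_cache_kind(use_quantize_kv, use_compress_kv, current_kind):
    """Model of the minicpm chain; returns the resulting cache kind.

    current_kind is one of: 'legacy', 'normal', 'fp8', 'compress', 'compress_fp8'.
    """
    if use_compress_kv and current_kind not in ("compress", "compress_fp8"):
        return "compress_fp8" if use_quantize_kv else "compress"
    elif use_quantize_kv and not use_compress_kv and current_kind != "fp8":
        return "fp8"
    elif not use_quantize_kv and not use_compress_kv and current_kind != "normal":
        return "normal"
    return current_kind


# Compression wins over quantization when both apply.
assert select_cache_kind(True, True, "legacy") == "compress_fp8"
# An already-compressed cache is left alone: use_compress_kv is True for it
# via the new isinstance() term, so the fp8/normal branches never fire.
assert select_cache_kind(True, True, "compress_fp8") == "compress_fp8"
# Without compression, use_quantize_kv decides between fp8 and normal.
assert select_cache_kind(True, False, "legacy") == "fp8"
assert select_cache_kind(False, False, "legacy") == "normal"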
@@ -256,7 +256,8 @@ def phi3_model_forward_wrapper(origin_model_forward):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
-        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
+        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
+            isinstance(past_key_values, DynamicCompressCache)
         if use_cache:
             if use_compress_kv and not isinstance(past_key_values,
                                                   DynamicCompressCache):
@@ -264,13 +265,11 @@ def phi3_model_forward_wrapper(origin_model_forward):
                     past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
                 else:
                     past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
-            if use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
-                                                                    DynamicCompressCache)):
+            if use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                          DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
             if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
-                                                                              (DynamicNormalCache,
-                                                                               DynamicCompressCache
-                                                                               )):
+                                                                              DynamicNormalCache):
                 past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
             if past_key_values.get_seq_length() == 0:
                 n_layer = self.config.num_hidden_layers
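phi3 differs in that the quantize and normal branches are plain if statements rather than elifs, so the newly added "not use_compress_kv" guards are what keep those branches from re-wrapping a cache the compress branch just produced; that is the job the old tuples containing DynamicCompressCache were doing. A small sketch of that control flow over plain strings (hypothetical names, not ipex-llm code):

def phi3_style_chain(use_quantize_kv, use_compress_kv, kind):
    """Model of the phi3 branches, which are consecutive ifs, not elifs."""
    if use_compress_kv and kind not in ("compress", "compress_fp8"):
        kind = "compress_fp8" if use_quantize_kv else "compress"
    # Without "not use_compress_kv", this branch would re-wrap the compressed
    # cache created just above (it is not an fp8 cache); that exclusion was
    # previously handled by the (DynamicFp8Cache, DynamicCompressCache) tuple.
    if use_quantize_kv and not use_compress_kv and kind != "fp8":
        kind = "fp8"
    if not use_quantize_kv and not use_compress_kv and kind != "normal":
        kind = "normal"
    return kind


assert phi3_style_chain(True, True, "legacy") == "compress_fp8"
assert phi3_style_chain(False, True, "legacy") == "compress"
assert phi3_style_chain(True, False, "legacy") == "fp8"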
@@ -120,7 +120,8 @@ def qwen2_model_forward(
                        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
                                                  self.config.num_attention_heads//self.config.num_key_value_heads)
                        )
-    use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
+    use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
+        isinstance(past_key_values, DynamicCompressCache)

     if use_cache:
         if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
@@ -128,12 +129,11 @@ def qwen2_model_forward(
                 past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
             else:
                 past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
-        elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
-                                                                  DynamicCompressCache)):
+        elif use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                        DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
-                                                                          (DynamicNormalCache,
-                                                                           DynamicCompressCache)):
+                                                                          DynamicNormalCache):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
         past_key_values_length = past_key_values.get_usable_length(seq_length)
     # ipex-llm changes end
@@ -316,7 +316,8 @@ def qwen2_model_forward_4_42(
                        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds,
                                                  self.config.num_attention_heads//self.config.num_key_value_heads)
                        )
-    use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1])
+    use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1]) or \
+        isinstance(past_key_values, DynamicCompressCache)

     if use_cache:
         if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
@@ -324,12 +325,11 @@ def qwen2_model_forward_4_42(
                 past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
             else:
                 past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
-        elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
-                                                                  DynamicCompressCache)):
+        elif use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                        DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
-                                                                            (DynamicNormalCache,
-                                                                             DynamicCompressCache)):
+        if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                          DynamicNormalCache):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
     # ipex-llm changes end

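Both qwen2 forwards follow the phi3-style layout and also carry the DynamicNormalCache fallback for the no-quantize, no-compress case. A self-contained sketch of the full three-way selection, with stub classes and a hypothetical qwen2_select helper in place of the real ipex-llm code:

# Stand-ins for ipex_llm.transformers.kv types; qwen2_select is illustrative only.
class DynamicNormalCache(list): ...
class DynamicFp8Cache(list): ...
class DynamicCompressCache(list): ...
class DynamicCompressFp8Cache(DynamicCompressCache): ...


def qwen2_select(past_key_values, use_quantize_kv, heuristic_compress):
    use_compress_kv = heuristic_compress or \
        isinstance(past_key_values, DynamicCompressCache)
    if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
        cls = DynamicCompressFp8Cache if use_quantize_kv else DynamicCompressCache
        past_key_values = cls(past_key_values or [])
    elif use_quantize_kv and not use_compress_kv and \
            not isinstance(past_key_values, DynamicFp8Cache):
        past_key_values = DynamicFp8Cache(past_key_values or [])
    if not use_quantize_kv and not use_compress_kv and \
            not isinstance(past_key_values, DynamicNormalCache):
        # Fallback: wrap legacy tuples into the normal dynamic cache.
        past_key_values = DynamicNormalCache(past_key_values or [])
    return past_key_values


# Neither quantization nor compression: legacy input becomes a normal cache
# on the first step and is passed through unchanged on later steps.
cache = qwen2_select(None, use_quantize_kv=False, heuristic_compress=False)
assert isinstance(cache, DynamicNormalCache)
assert qwen2_select(cache, False, False) is cache

The last assert shows the fallback is idempotent: once the legacy tuples are wrapped, later steps leave the cache object alone.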