Support minicpm compresskv & modify default compresskv config & default enable compresskv on mtl 2.5k~4.5k (#11726)
* support minicpm & modify default & default enable on mtl 2.5k~4.5k
* fix style

parent c093f7d980
commit a71ae7c22b

8 changed files with 143 additions and 89 deletions

@@ -146,13 +146,14 @@ def compress_kv(attn_config, key_states, query_states, value_states, attention_m
     if not hasattr(attn_config, 'window_size'):
         attn_config.window_size = 32
     if not hasattr(attn_config, 'max_capacity_prompt'):
-        attn_config.max_capacity_prompt = 512
+        attn_config.max_capacity_prompt = 1024
     if not hasattr(attn_config, 'kernel_size'):
-        attn_config.kernel_size = 5
+        attn_config.kernel_size = 7
     if not hasattr(attn_config, 'pooling'):
-        attn_config.pooling = 'avgpool'
+        attn_config.pooling = 'maxpool'
     bsz, num_heads, q_len, head_dim = query_states.shape
-    if q_len < attn_config.max_capacity_prompt:
+    print(f"attn_config.max_capacity_prompt: ", attn_config.max_capacity_prompt, " ", q_len)
+    if q_len <= attn_config.max_capacity_prompt:
         return key_states, value_states
     else:
         key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)

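Taken together, the new CompressKV defaults above amount to the following small standalone sketch (apply_compresskv_defaults and the SimpleNamespace config object are illustrative stand-ins, not ipex-llm APIs):

from types import SimpleNamespace

def apply_compresskv_defaults(attn_config):
    # Defaults after this change: larger prompt budget, wider pooling kernel, maxpool.
    defaults = {"window_size": 32, "max_capacity_prompt": 1024,
                "kernel_size": 7, "pooling": "maxpool"}
    for name, value in defaults.items():
        if not hasattr(attn_config, name):
            setattr(attn_config, name, value)
    return attn_config

cfg = apply_compresskv_defaults(SimpleNamespace())
print(cfg.max_capacity_prompt, cfg.kernel_size, cfg.pooling)  # 1024 7 maxpool

Prompts no longer than max_capacity_prompt are returned uncompressed; note the comparison also changed from < to <=.
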
@@ -87,7 +87,7 @@ def chatglm2_model_forward(
             dtype=inputs_embeds.dtype, device=inputs_embeds.device)

     if use_cache:
-        use_compress_kv = should_use_compresskv(input_ids)
+        use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[-1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 input_ids)
         if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,

@@ -50,7 +50,7 @@ def chatglm4_model_forward(

     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
-        use_compress_kv = should_use_compresskv(inputs)
+        use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 inputs)
         if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,

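The same signature change, passing the prompt length explicitly as a second argument, repeats at every should_use_compresskv call site in this commit. A minimal illustration of where that argument comes from (the actual call is commented out because it requires ipex-llm on an XPU device):

import torch

input_ids = torch.randint(0, 32000, (1, 3000))  # batch of one 3000-token prompt
prompt_len = input_ids.shape[-1]                # 3000, the new second argument
# use_compress_kv = should_use_compresskv(input_ids, prompt_len)
print(prompt_len)  # 3000
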
@@ -122,7 +122,7 @@ def llama_model_forward_4_36(
                               self.config.num_attention_heads//self.config.num_key_value_heads):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-    elif should_use_compresskv(input):
+    elif should_use_compresskv(input, input.shape[-1]):
         # if use quantize kv, compress kv will be ignored now
         if not isinstance(past_key_values, DynamicCompressCache):
             past_key_values = DynamicCompressCache.from_legacy_cache(

@@ -162,7 +162,7 @@ def llama_model_forward_4_38(
                               self.config.num_attention_heads//self.config.num_key_value_heads):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-    elif should_use_compresskv(input):
+    elif should_use_compresskv(input, input.shape[-1]):
         # if use quantize kv, compress kv will be ignored now
         if not isinstance(past_key_values, DynamicCompressCache):
             past_key_values = DynamicCompressCache.from_legacy_cache(

@@ -203,7 +203,7 @@ def llama_model_forward_4_41(
                               self.config.num_attention_heads//self.config.num_key_value_heads):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-    elif should_use_compresskv(input):
+    elif should_use_compresskv(input, input.shape[-1]):
         # if use quantize kv, compress kv will be ignored now
         if not isinstance(past_key_values, DynamicCompressCache):
             past_key_values = DynamicCompressCache.from_legacy_cache(

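As the comment in these hunks notes, quantized KV cache still takes precedence over CompressKV. A simplified sketch of the selection order shared by the three llama model forwards (pick_kv_cache is an illustrative helper and the strings stand in for the ipex-llm cache classes):

def pick_kv_cache(use_quantize_kv: bool, use_compress_kv: bool) -> str:
    if use_quantize_kv:
        return "DynamicFp8Cache"        # quantized KV wins
    if use_compress_kv:
        return "DynamicCompressCache"   # CompressKV only when not quantizing
    return "default cache"

print(pick_kv_cache(True, True))    # DynamicFp8Cache
print(pick_kv_cache(False, True))   # DynamicCompressCache
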
@@ -1283,6 +1283,7 @@ def llama_attention_forward_4_41_original(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "

@@ -1295,7 +1296,7 @@ def llama_attention_forward_4_41_original(
     original_dtype = hidden_states.dtype

     # [CompressKV]
-    use_compresskv = should_use_compresskv(hidden_states)
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)

@@ -1834,6 +1835,7 @@ def llama_attention_forward_4_38_original(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "

@@ -1846,7 +1848,7 @@ def llama_attention_forward_4_38_original(
     original_dtype = hidden_states.dtype

     # [CompressKV]
-    use_compresskv = should_use_compresskv(hidden_states)
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)

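The attention forwards no longer call the heuristic on hidden_states; the decision is made once in the model forward, which installs the cache, and each layer only checks the cache type. A toy sketch of that pattern (CompressCache is a stand-in for DynamicCompressCache):

class CompressCache:
    pass

def layer_uses_compresskv(past_key_value) -> bool:
    # Per-layer check: follow whatever cache the model forward installed.
    return isinstance(past_key_value, CompressCache)

print(layer_uses_compresskv(CompressCache()))  # True
print(layer_uses_compresskv(None))             # False
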
@@ -49,7 +49,7 @@ from ipex_llm.transformers.models.utils import SILU
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
-    apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
+    apply_rotary_pos_emb, is_enough_kv_cache_room_4_36, should_use_compresskv
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check

@@ -111,6 +111,7 @@ def minicpm_attention_forward_original(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "

@@ -122,6 +123,9 @@ def minicpm_attention_forward_original(
     # for flash attention
     original_dtype = hidden_states.dtype

+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1

@@ -154,7 +158,11 @@ def minicpm_attention_forward_original(
                                              self.rotary_emb.base,)
             kv_seq_len += 1
             # update past_key_value's seem_tokens and kv caches.
-            if self.layer_idx == 0:
+            # [CompressKV]
+            if use_compresskv:
+                past_key_value.update_seen_tokens(self.layer_idx, q_len)
+                kv_seq_len = past_key_value.get_seq_length()
+            elif self.layer_idx == 0:
                 past_key_value.seen_tokens = kv_seq_len
             past_key_value.key_cache[self.layer_idx] = key_states
             past_key_value.value_cache[self.layer_idx] = value_states

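On the compressed prefill path the cache owns the token bookkeeping: update_seen_tokens records the prompt length and get_seq_length reports it back, instead of the layer writing seen_tokens directly. A toy stand-in for those two DynamicCompressCache calls (the real class in ipex_llm.transformers.kv also performs the actual compression):

class ToyCompressCache:
    def __init__(self):
        self.seen_tokens = 0

    def update_seen_tokens(self, layer_idx: int, q_len: int) -> None:
        if layer_idx == 0:          # the toy only counts once per forward pass
            self.seen_tokens += q_len

    def get_seq_length(self) -> int:
        return self.seen_tokens

cache = ToyCompressCache()
cache.update_seen_tokens(0, 3000)   # prefill with a 3000-token prompt
print(cache.get_seq_length())       # 3000
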
@@ -256,6 +264,12 @@ def minicpm_attention_forward_original(
                                                         cos, sin, position_ids, "llama")

         if past_key_value is not None:
+            if use_compresskv:
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx,
+                    query_states, attention_mask, self.num_key_value_groups,
+                    self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+            else:
                 # update the number of seen tokens
                 if self.layer_idx == 0:
                     past_key_value.seen_tokens += key_states.shape[-2]

@@ -312,6 +326,9 @@ def minicpm_attention_forward_original(
         elif not self.training and not hidden_states.requires_grad and \
                 use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
             import xe_addons
+            if use_compresskv:
+                # [CompressKV] set attention_mask = None
+                new_attention_mask = None
             attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                         new_attention_mask)
             attn_output = attn_output.view(query_states.shape)

@@ -600,14 +617,19 @@ def minicpm_model_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+    if use_cache:
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
                                  self.config.num_attention_heads //
                                  self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+        elif should_use_compresskv(input, input.shape[-1]):
+            if not isinstance(past_key_values, DynamicCompressCache):
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+
     return minicpm_model_forward_internal(
         self=self,
         input_ids=input_ids,

@@ -782,6 +804,8 @@ def minicpm_attention_forward_original_4_39(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
+
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "

@@ -793,6 +817,9 @@ def minicpm_attention_forward_original_4_39(
     # for flash attention
     original_dtype = hidden_states.dtype

+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1

@@ -825,7 +852,11 @@ def minicpm_attention_forward_original_4_39(
                                              self.rotary_emb.base,)
             kv_seq_len += 1
             # update past_key_value's seem_tokens and kv caches.
-            if self.layer_idx == 0:
+            # [CompressKV]
+            if use_compresskv:
+                past_key_value.update_seen_tokens(self.layer_idx, q_len)
+                kv_seq_len = past_key_value.get_seq_length()
+            elif self.layer_idx == 0:
                 past_key_value._seen_tokens = kv_seq_len
             past_key_value.key_cache[self.layer_idx] = key_states
             past_key_value.value_cache[self.layer_idx] = value_states

@@ -927,6 +958,12 @@ def minicpm_attention_forward_original_4_39(
                                                         cos, sin, position_ids, "llama")

         if past_key_value is not None:
+            if use_compresskv:
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx,
+                    query_states, attention_mask, self.num_key_value_groups,
+                    self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+            else:
                 # update the number of seen tokens
                 if self.layer_idx == 0:
                     past_key_value._seen_tokens += key_states.shape[-2]

@@ -983,6 +1020,9 @@ def minicpm_attention_forward_original_4_39(
         elif not self.training and not hidden_states.requires_grad and \
                 use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
             import xe_addons
+            if use_compresskv:
+                # [CompressKV] set attention_mask = None
+                new_attention_mask = None
             attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                         new_attention_mask)
             attn_output = attn_output.view(query_states.shape)

@@ -210,7 +210,7 @@ def mistral_model_forward_4_36(
                               self.config.num_attention_heads//self.config.num_key_value_heads):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-    elif should_use_compresskv(input_ids):
+    elif should_use_compresskv(input_ids, input_ids.shape[-1]):
         # if use quantize kv, compress kv will be ignored now
         if not isinstance(past_key_values, DynamicCompressCache):
             past_key_values = DynamicCompressCache.from_legacy_cache(

@@ -902,13 +902,15 @@ def mistral_attention_forward_4_36_original(
     use_cache: bool=False,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
+
     bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
     # for flash attention
     original_dtype = hidden_states.dtype

     # [CompressKV]
-    use_compresskv = should_use_compresskv(hidden_states)
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)

@@ -1156,13 +1158,14 @@ def mistral_attention_forward_4_39_original(
     use_cache: bool=False,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
     bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
     # for flash attention
     original_dtype = hidden_states.dtype

     # [CompressKV]
-    use_compresskv = should_use_compresskv(hidden_states)
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)

@@ -118,7 +118,7 @@ def qwen2_model_forward(
                        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
                                                  self.config.num_attention_heads//self.config.num_key_value_heads)
                        )
-    use_compress_kv = should_use_compresskv(inputs)
+    use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])

     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):

@@ -401,7 +401,8 @@ def qwen2_attention_forward(
     device = hidden_states.device

     # [CompressKV]
-    use_compresskv = should_use_compresskv(hidden_states)
+    from ipex_llm.transformers.kv import DynamicCompressCache
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
         qkv = self.qkv_proj(hidden_states)

@@ -481,6 +481,13 @@ def update_past_key_value(past_key_value, key_states, value_states,
     return key_states, value_states


-def should_use_compresskv(x: torch.Tensor):
+def should_use_compresskv(x: torch.Tensor, prompt_len: int):
     use_compress_kv = os.environ.get("IPEX_LLM_COMPRESS_KV_CACHE", None)
-    return x.device.type == 'xpu' and use_compress_kv == "1"
+    if use_compress_kv is None:
+        return (
+            get_xpu_device_type(x) == "mtl"
+            and prompt_len >= 2500
+            and prompt_len <= 4500
+        )
+    else:
+        return x.device.type == 'xpu' and use_compress_kv == "1"

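Restated as a standalone sketch, the new default-enable rule looks like the following (compresskv_enabled and its device_type string are illustrative; the real function takes the input tensor, resolves the device via get_xpu_device_type, and its env-var override additionally requires an xpu tensor):

import os

def compresskv_enabled(device_type: str, prompt_len: int) -> bool:
    flag = os.environ.get("IPEX_LLM_COMPRESS_KV_CACHE", None)
    if flag is None:
        # new default: auto-enable on MTL for 2.5k-4.5k token prompts
        return device_type == "mtl" and 2500 <= prompt_len <= 4500
    return flag == "1"   # explicit "1" forces on, anything else forces off

print(compresskv_enabled("mtl", 3000))   # True  (inside the window)
print(compresskv_enabled("mtl", 2000))   # False (below the window)
print(compresskv_enabled("arc", 3000))   # False (not MTL, no override set)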