Support minicpm compresskv & modify default compresskv config & default enable compresskv on mtl 2.5k~4.5k (#11726)
* support minicpm & modify default & default enable on mtl 2.5k~4.5k * fix style
This commit is contained in:
		
							parent
							
								
									c093f7d980
								
							
						
					
					
						commit
						a71ae7c22b
					
				
					 8 changed files with 143 additions and 89 deletions
				
			
		| 
						 | 
				
			
			@ -146,13 +146,14 @@ def compress_kv(attn_config, key_states, query_states, value_states, attention_m
 | 
			
		|||
    if not hasattr(attn_config, 'window_size'):
 | 
			
		||||
        attn_config.window_size = 32
 | 
			
		||||
    if not hasattr(attn_config, 'max_capacity_prompt'):
 | 
			
		||||
        attn_config.max_capacity_prompt = 512
 | 
			
		||||
        attn_config.max_capacity_prompt = 1024
 | 
			
		||||
    if not hasattr(attn_config, 'kernel_size'):
 | 
			
		||||
        attn_config.kernel_size = 5
 | 
			
		||||
        attn_config.kernel_size = 7
 | 
			
		||||
    if not hasattr(attn_config, 'pooling'):
 | 
			
		||||
        attn_config.pooling = 'avgpool'
 | 
			
		||||
        attn_config.pooling = 'maxpool'
 | 
			
		||||
    bsz, num_heads, q_len, head_dim = query_states.shape
 | 
			
		||||
    if q_len < attn_config.max_capacity_prompt:
 | 
			
		||||
    print(f"attn_config.max_capacity_prompt: ", attn_config.max_capacity_prompt, " ", q_len)
 | 
			
		||||
    if q_len <= attn_config.max_capacity_prompt:
 | 
			
		||||
        return key_states, value_states
 | 
			
		||||
    else:
 | 
			
		||||
        key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -87,7 +87,7 @@ def chatglm2_model_forward(
 | 
			
		|||
                                dtype=inputs_embeds.dtype, device=inputs_embeds.device)
 | 
			
		||||
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        use_compress_kv = should_use_compresskv(input_ids)
 | 
			
		||||
        use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[-1])
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
 | 
			
		||||
                                                input_ids)
 | 
			
		||||
        if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -50,7 +50,7 @@ def chatglm4_model_forward(
 | 
			
		|||
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        inputs = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
        use_compress_kv = should_use_compresskv(inputs)
 | 
			
		||||
        use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
 | 
			
		||||
                                                inputs)
 | 
			
		||||
        if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -122,7 +122,7 @@ def llama_model_forward_4_36(
 | 
			
		|||
                                 self.config.num_attention_heads//self.config.num_key_value_heads):
 | 
			
		||||
            if not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
        elif should_use_compresskv(input):
 | 
			
		||||
        elif should_use_compresskv(input, input.shape[-1]):
 | 
			
		||||
            # if use quantize kv, compress kv will be ignored now
 | 
			
		||||
            if not isinstance(past_key_values, DynamicCompressCache):
 | 
			
		||||
                past_key_values = DynamicCompressCache.from_legacy_cache(
 | 
			
		||||
| 
						 | 
				
			
			@ -162,7 +162,7 @@ def llama_model_forward_4_38(
 | 
			
		|||
                                 self.config.num_attention_heads//self.config.num_key_value_heads):
 | 
			
		||||
            if not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
        elif should_use_compresskv(input):
 | 
			
		||||
        elif should_use_compresskv(input, input.shape[-1]):
 | 
			
		||||
            # if use quantize kv, compress kv will be ignored now
 | 
			
		||||
            if not isinstance(past_key_values, DynamicCompressCache):
 | 
			
		||||
                past_key_values = DynamicCompressCache.from_legacy_cache(
 | 
			
		||||
| 
						 | 
				
			
			@ -203,7 +203,7 @@ def llama_model_forward_4_41(
 | 
			
		|||
                                 self.config.num_attention_heads//self.config.num_key_value_heads):
 | 
			
		||||
            if not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
        elif should_use_compresskv(input):
 | 
			
		||||
        elif should_use_compresskv(input, input.shape[-1]):
 | 
			
		||||
            # if use quantize kv, compress kv will be ignored now
 | 
			
		||||
            if not isinstance(past_key_values, DynamicCompressCache):
 | 
			
		||||
                past_key_values = DynamicCompressCache.from_legacy_cache(
 | 
			
		||||
| 
						 | 
				
			
			@ -1283,6 +1283,7 @@ def llama_attention_forward_4_41_original(
 | 
			
		|||
    cache_position: Optional[torch.LongTensor] = None,
 | 
			
		||||
    **kwargs
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicCompressCache
 | 
			
		||||
    if "padding_mask" in kwargs:
 | 
			
		||||
        warnings.warn(
 | 
			
		||||
            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
 | 
			
		||||
| 
						 | 
				
			
			@ -1295,7 +1296,7 @@ def llama_attention_forward_4_41_original(
 | 
			
		|||
    original_dtype = hidden_states.dtype
 | 
			
		||||
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    use_compresskv = should_use_compresskv(hidden_states)
 | 
			
		||||
    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
 | 
			
		||||
| 
						 | 
				
			
			@ -1834,6 +1835,7 @@ def llama_attention_forward_4_38_original(
 | 
			
		|||
    cache_position: Optional[torch.LongTensor] = None,
 | 
			
		||||
    **kwargs
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicCompressCache
 | 
			
		||||
    if "padding_mask" in kwargs:
 | 
			
		||||
        warnings.warn(
 | 
			
		||||
            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
 | 
			
		||||
| 
						 | 
				
			
			@ -1846,7 +1848,7 @@ def llama_attention_forward_4_38_original(
 | 
			
		|||
    original_dtype = hidden_states.dtype
 | 
			
		||||
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    use_compresskv = should_use_compresskv(hidden_states)
 | 
			
		||||
    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -49,7 +49,7 @@ from ipex_llm.transformers.models.utils import SILU
 | 
			
		|||
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
 | 
			
		||||
    restore_fp8_kv_cache, use_quantize_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
 | 
			
		||||
    apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 | 
			
		||||
    apply_rotary_pos_emb, is_enough_kv_cache_room_4_36, should_use_compresskv
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
 | 
			
		||||
from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 | 
			
		||||
| 
						 | 
				
			
			@ -111,6 +111,7 @@ def minicpm_attention_forward_original(
 | 
			
		|||
    cache_position: Optional[torch.LongTensor] = None,
 | 
			
		||||
    **kwargs
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicCompressCache
 | 
			
		||||
    if "padding_mask" in kwargs:
 | 
			
		||||
        warnings.warn(
 | 
			
		||||
            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
 | 
			
		||||
| 
						 | 
				
			
			@ -122,6 +123,9 @@ def minicpm_attention_forward_original(
 | 
			
		|||
    # for flash attention
 | 
			
		||||
    original_dtype = hidden_states.dtype
 | 
			
		||||
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
 | 
			
		||||
    no_tp = not self.config.pretraining_tp > 1
 | 
			
		||||
| 
						 | 
				
			
			@ -154,7 +158,11 @@ def minicpm_attention_forward_original(
 | 
			
		|||
                                                                       self.rotary_emb.base,)
 | 
			
		||||
        kv_seq_len += 1
 | 
			
		||||
        # update past_key_value's seem_tokens and kv caches.
 | 
			
		||||
        if self.layer_idx == 0:
 | 
			
		||||
        # [CompressKV]
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            past_key_value.update_seen_tokens(self.layer_idx, q_len)
 | 
			
		||||
            kv_seq_len = past_key_value.get_seq_length()
 | 
			
		||||
        elif self.layer_idx == 0:
 | 
			
		||||
            past_key_value.seen_tokens = kv_seq_len
 | 
			
		||||
        past_key_value.key_cache[self.layer_idx] = key_states
 | 
			
		||||
        past_key_value.value_cache[self.layer_idx] = value_states
 | 
			
		||||
| 
						 | 
				
			
			@ -256,6 +264,12 @@ def minicpm_attention_forward_original(
 | 
			
		|||
                                                                cos, sin, position_ids, "llama")
 | 
			
		||||
 | 
			
		||||
        if past_key_value is not None:
 | 
			
		||||
            if use_compresskv:
 | 
			
		||||
                key_states, value_states = past_key_value.update(
 | 
			
		||||
                    key_states, value_states, self.layer_idx,
 | 
			
		||||
                    query_states, attention_mask, self.num_key_value_groups,
 | 
			
		||||
                    self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
 | 
			
		||||
            else:
 | 
			
		||||
                # update the number of seen tokens
 | 
			
		||||
                if self.layer_idx == 0:
 | 
			
		||||
                    past_key_value.seen_tokens += key_states.shape[-2]
 | 
			
		||||
| 
						 | 
				
			
			@ -312,6 +326,9 @@ def minicpm_attention_forward_original(
 | 
			
		|||
    elif not self.training and not hidden_states.requires_grad and \
 | 
			
		||||
            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            # [CompressKV] set attention_mask = None
 | 
			
		||||
            new_attention_mask = None
 | 
			
		||||
        attn_output = xe_addons.sdp(query_states, key_states, value_states,
 | 
			
		||||
                                    new_attention_mask)
 | 
			
		||||
        attn_output = attn_output.view(query_states.shape)
 | 
			
		||||
| 
						 | 
				
			
			@ -600,14 +617,19 @@ def minicpm_model_forward(
 | 
			
		|||
    output_hidden_states: Optional[bool] = None,
 | 
			
		||||
    return_dict: Optional[bool] = None,
 | 
			
		||||
) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    input = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
 | 
			
		||||
                                 self.config.num_attention_heads //
 | 
			
		||||
                                 self.config.num_key_value_heads):
 | 
			
		||||
            if not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
        elif should_use_compresskv(input, input.shape[-1]):
 | 
			
		||||
            if not isinstance(past_key_values, DynamicCompressCache):
 | 
			
		||||
                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
 | 
			
		||||
 | 
			
		||||
    return minicpm_model_forward_internal(
 | 
			
		||||
        self=self,
 | 
			
		||||
        input_ids=input_ids,
 | 
			
		||||
| 
						 | 
				
			
			@ -782,6 +804,8 @@ def minicpm_attention_forward_original_4_39(
 | 
			
		|||
    cache_position: Optional[torch.LongTensor] = None,
 | 
			
		||||
    **kwargs
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicCompressCache
 | 
			
		||||
 | 
			
		||||
    if "padding_mask" in kwargs:
 | 
			
		||||
        warnings.warn(
 | 
			
		||||
            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
 | 
			
		||||
| 
						 | 
				
			
			@ -793,6 +817,9 @@ def minicpm_attention_forward_original_4_39(
 | 
			
		|||
    # for flash attention
 | 
			
		||||
    original_dtype = hidden_states.dtype
 | 
			
		||||
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
 | 
			
		||||
    no_tp = not self.config.pretraining_tp > 1
 | 
			
		||||
| 
						 | 
				
			
			@ -825,7 +852,11 @@ def minicpm_attention_forward_original_4_39(
 | 
			
		|||
                                                                       self.rotary_emb.base,)
 | 
			
		||||
        kv_seq_len += 1
 | 
			
		||||
        # update past_key_value's seem_tokens and kv caches.
 | 
			
		||||
        if self.layer_idx == 0:
 | 
			
		||||
        # [CompressKV]
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            past_key_value.update_seen_tokens(self.layer_idx, q_len)
 | 
			
		||||
            kv_seq_len = past_key_value.get_seq_length()
 | 
			
		||||
        elif self.layer_idx == 0:
 | 
			
		||||
            past_key_value._seen_tokens = kv_seq_len
 | 
			
		||||
        past_key_value.key_cache[self.layer_idx] = key_states
 | 
			
		||||
        past_key_value.value_cache[self.layer_idx] = value_states
 | 
			
		||||
| 
						 | 
				
			
			@ -927,6 +958,12 @@ def minicpm_attention_forward_original_4_39(
 | 
			
		|||
                                                                cos, sin, position_ids, "llama")
 | 
			
		||||
 | 
			
		||||
        if past_key_value is not None:
 | 
			
		||||
            if use_compresskv:
 | 
			
		||||
                key_states, value_states = past_key_value.update(
 | 
			
		||||
                    key_states, value_states, self.layer_idx,
 | 
			
		||||
                    query_states, attention_mask, self.num_key_value_groups,
 | 
			
		||||
                    self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
 | 
			
		||||
            else:
 | 
			
		||||
                # update the number of seen tokens
 | 
			
		||||
                if self.layer_idx == 0:
 | 
			
		||||
                    past_key_value._seen_tokens += key_states.shape[-2]
 | 
			
		||||
| 
						 | 
				
			
			@ -983,6 +1020,9 @@ def minicpm_attention_forward_original_4_39(
 | 
			
		|||
    elif not self.training and not hidden_states.requires_grad and \
 | 
			
		||||
            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            # [CompressKV] set attention_mask = None
 | 
			
		||||
            new_attention_mask = None
 | 
			
		||||
        attn_output = xe_addons.sdp(query_states, key_states, value_states,
 | 
			
		||||
                                    new_attention_mask)
 | 
			
		||||
        attn_output = attn_output.view(query_states.shape)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -210,7 +210,7 @@ def mistral_model_forward_4_36(
 | 
			
		|||
                                 self.config.num_attention_heads//self.config.num_key_value_heads):
 | 
			
		||||
            if not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
        elif should_use_compresskv(input_ids):
 | 
			
		||||
        elif should_use_compresskv(input_ids, input_ids.shape[-1]):
 | 
			
		||||
            # if use quantize kv, compress kv will be ignored now
 | 
			
		||||
            if not isinstance(past_key_values, DynamicCompressCache):
 | 
			
		||||
                past_key_values = DynamicCompressCache.from_legacy_cache(
 | 
			
		||||
| 
						 | 
				
			
			@ -902,13 +902,15 @@ def mistral_attention_forward_4_36_original(
 | 
			
		|||
    use_cache: bool=False,
 | 
			
		||||
    **kwargs
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicCompressCache
 | 
			
		||||
 | 
			
		||||
    bsz, q_len, hidden_size = hidden_states.size()
 | 
			
		||||
    device = hidden_states.device
 | 
			
		||||
    # for flash attention
 | 
			
		||||
    original_dtype = hidden_states.dtype
 | 
			
		||||
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    use_compresskv = should_use_compresskv(hidden_states)
 | 
			
		||||
    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
 | 
			
		||||
| 
						 | 
				
			
			@ -1156,13 +1158,14 @@ def mistral_attention_forward_4_39_original(
 | 
			
		|||
    use_cache: bool=False,
 | 
			
		||||
    **kwargs
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicCompressCache
 | 
			
		||||
    bsz, q_len, hidden_size = hidden_states.size()
 | 
			
		||||
    device = hidden_states.device
 | 
			
		||||
    # for flash attention
 | 
			
		||||
    original_dtype = hidden_states.dtype
 | 
			
		||||
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    use_compresskv = should_use_compresskv(hidden_states)
 | 
			
		||||
    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -118,7 +118,7 @@ def qwen2_model_forward(
 | 
			
		|||
        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
 | 
			
		||||
                                  self.config.num_attention_heads//self.config.num_key_value_heads)
 | 
			
		||||
    )
 | 
			
		||||
    use_compress_kv = should_use_compresskv(inputs)
 | 
			
		||||
    use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
 | 
			
		||||
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
| 
						 | 
				
			
			@ -401,7 +401,8 @@ def qwen2_attention_forward(
 | 
			
		|||
    device = hidden_states.device
 | 
			
		||||
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    use_compresskv = should_use_compresskv(hidden_states)
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicCompressCache
 | 
			
		||||
    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 | 
			
		||||
 | 
			
		||||
    if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
 | 
			
		||||
        qkv = self.qkv_proj(hidden_states)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -481,6 +481,13 @@ def update_past_key_value(past_key_value, key_states, value_states,
 | 
			
		|||
    return key_states, value_states
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def should_use_compresskv(x: torch.Tensor):
 | 
			
		||||
def should_use_compresskv(x: torch.Tensor, prompt_len: int):
 | 
			
		||||
    use_compress_kv = os.environ.get("IPEX_LLM_COMPRESS_KV_CACHE", None)
 | 
			
		||||
    if use_compress_kv is None:
 | 
			
		||||
        return (
 | 
			
		||||
            get_xpu_device_type(x) == "mtl"
 | 
			
		||||
            and prompt_len >= 2500
 | 
			
		||||
            and prompt_len <= 4500
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        return x.device.type == 'xpu' and use_compress_kv == "1"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue