diff --git a/python/llm/src/ipex_llm/transformers/kv.py b/python/llm/src/ipex_llm/transformers/kv.py
index ae253239..e2f386ea 100644
--- a/python/llm/src/ipex_llm/transformers/kv.py
+++ b/python/llm/src/ipex_llm/transformers/kv.py
@@ -146,13 +146,14 @@ def compress_kv(attn_config, key_states, query_states, value_states, attention_m
     if not hasattr(attn_config, 'window_size'):
         attn_config.window_size = 32
     if not hasattr(attn_config, 'max_capacity_prompt'):
-        attn_config.max_capacity_prompt = 512
+        attn_config.max_capacity_prompt = 1024
     if not hasattr(attn_config, 'kernel_size'):
-        attn_config.kernel_size = 5
+        attn_config.kernel_size = 7
     if not hasattr(attn_config, 'pooling'):
-        attn_config.pooling = 'avgpool'
+        attn_config.pooling = 'maxpool'
     bsz, num_heads, q_len, head_dim = query_states.shape
-    if q_len < attn_config.max_capacity_prompt:
+    print(f"attn_config.max_capacity_prompt: ", attn_config.max_capacity_prompt, " ", q_len)
+    if q_len <= attn_config.max_capacity_prompt:
         return key_states, value_states
     else:
         key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index 006ff331..fe504960 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -87,7 +87,7 @@ def chatglm2_model_forward(
                                         dtype=inputs_embeds.dtype, device=inputs_embeds.device)
 
     if use_cache:
-        use_compress_kv = should_use_compresskv(input_ids)
+        use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[-1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 input_ids)
         if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index 361d405c..46d7d780 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -50,7 +50,7 @@ def chatglm4_model_forward(
 
     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
-        use_compress_kv = should_use_compresskv(inputs)
+        use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 inputs)
         if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index fd7ecffd..805a129e 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -122,7 +122,7 @@ def llama_model_forward_4_36(
                                  self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input):
+        elif should_use_compresskv(input, input.shape[-1]):
             # if use quantize kv, compress kv will be ignored now
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(
@@ -162,7 +162,7 @@ def llama_model_forward_4_38(
                                  self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input):
+        elif should_use_compresskv(input, input.shape[-1]):
             # if use quantize kv, compress kv will be ignored now
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(
@@ -203,7 +203,7 @@ def llama_model_forward_4_41(
                                  self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input):
+        elif should_use_compresskv(input, input.shape[-1]):
             # if use quantize kv, compress kv will be ignored now
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(
@@ -1283,6 +1283,7 @@ def llama_attention_forward_4_41_original(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1295,7 +1296,7 @@ def llama_attention_forward_4_41_original(
     original_dtype = hidden_states.dtype
 
     # [CompressKV]
-    use_compresskv = should_use_compresskv(hidden_states)
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
@@ -1834,6 +1835,7 @@ def llama_attention_forward_4_38_original(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1846,7 +1848,7 @@ def llama_attention_forward_4_38_original(
     original_dtype = hidden_states.dtype
 
     # [CompressKV]
-    use_compresskv = should_use_compresskv(hidden_states)
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py
index 29db361c..50c0b9ce 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpm.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -49,7 +49,7 @@ from ipex_llm.transformers.models.utils import SILU
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
-    apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
+    apply_rotary_pos_emb, is_enough_kv_cache_room_4_36, should_use_compresskv
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
@@ -111,6 +111,7 @@ def minicpm_attention_forward_original(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -122,6 +123,9 @@ def minicpm_attention_forward_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
@@ -154,7 +158,11 @@ def minicpm_attention_forward_original(
                                                                        self.rotary_emb.base,)
         kv_seq_len += 1
         # update past_key_value's seem_tokens and kv caches.
-        if self.layer_idx == 0:
+        # [CompressKV]
+        if use_compresskv:
+            past_key_value.update_seen_tokens(self.layer_idx, q_len)
+            kv_seq_len = past_key_value.get_seq_length()
+        elif self.layer_idx == 0:
             past_key_value.seen_tokens = kv_seq_len
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
@@ -256,42 +264,48 @@ def minicpm_attention_forward_original(
                                                      cos, sin, position_ids, "llama")
 
     if past_key_value is not None:
-        # update the number of seen tokens
-        if self.layer_idx == 0:
-            past_key_value.seen_tokens += key_states.shape[-2]
-
-        # reuse k, v, self_attention
-        # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
-        if len(past_key_value.key_cache) <= self.layer_idx:
-            past_key_value.key_cache.append(key_states)
-            past_key_value.value_cache.append(value_states)
+        if use_compresskv:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
         else:
-            cache_k = past_key_value.key_cache[self.layer_idx]
-            cache_v = past_key_value.value_cache[self.layer_idx]
+            # update the number of seen tokens
+            if self.layer_idx == 0:
+                past_key_value.seen_tokens += key_states.shape[-2]
 
-            if not enough_kv_room:
-                # allocate new
-                new_c_k, new_c_v = extend_kv_cache(bsz,
-                                                   self.num_key_value_heads,  # Support GQA
-                                                   self.head_dim,
-                                                   cache_k.size(2),
-                                                   kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                   dtype=cache_k.dtype,
-                                                   device=device)
+            # reuse k, v, self_attention
+            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
+            if len(past_key_value.key_cache) <= self.layer_idx:
+                past_key_value.key_cache.append(key_states)
+                past_key_value.value_cache.append(value_states)
+            else:
+                cache_k = past_key_value.key_cache[self.layer_idx]
+                cache_v = past_key_value.value_cache[self.layer_idx]
 
-                new_c_k[:] = cache_k
-                new_c_v[:] = cache_v
-                cache_k = new_c_k
-                cache_v = new_c_v
+                if not enough_kv_room:
+                    # allocate new
+                    new_c_k, new_c_v = extend_kv_cache(bsz,
+                                                       self.num_key_value_heads,  # Support GQA
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
 
-            key_states, value_states = append_kv_cache(cache_k,
-                                                        cache_v,
-                                                        key_states,
-                                                        value_states)
+                    new_c_k[:] = cache_k
+                    new_c_v[:] = cache_v
+                    cache_k = new_c_k
+                    cache_v = new_c_v
 
-            # update past_key_value
-            past_key_value.key_cache[self.layer_idx] = key_states
-            past_key_value.value_cache[self.layer_idx] = value_states
+                key_states, value_states = append_kv_cache(cache_k,
+                                                           cache_v,
+                                                           key_states,
+                                                           value_states)
+
+                # update past_key_value
+                past_key_value.key_cache[self.layer_idx] = key_states
+                past_key_value.value_cache[self.layer_idx] = value_states
 
     if cache_position is not None:
         new_attention_mask = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
@@ -312,6 +326,9 @@ def minicpm_attention_forward_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
+        if use_compresskv:
+            # [CompressKV] set attention_mask = None
+            new_attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
@@ -600,14 +617,19 @@ def minicpm_model_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
-                                           self.config.num_attention_heads //
-                                           self.config.num_key_value_heads):
-        if not isinstance(past_key_values, DynamicFp8Cache):
-            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+    if use_cache:
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+                                 self.config.num_attention_heads //
+                                 self.config.num_key_value_heads):
+            if not isinstance(past_key_values, DynamicFp8Cache):
+                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+        elif should_use_compresskv(input, input.shape[-1]):
+            if not isinstance(past_key_values, DynamicCompressCache):
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+
     return minicpm_model_forward_internal(
         self=self,
         input_ids=input_ids,
@@ -782,6 +804,8 @@ def minicpm_attention_forward_original_4_39(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
+
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -793,6 +817,9 @@ def minicpm_attention_forward_original_4_39(
     # for flash attention
     original_dtype = hidden_states.dtype
 
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
@@ -825,7 +852,11 @@ def minicpm_attention_forward_original_4_39(
                                                                        self.rotary_emb.base,)
         kv_seq_len += 1
         # update past_key_value's seem_tokens and kv caches.
-        if self.layer_idx == 0:
+        # [CompressKV]
+        if use_compresskv:
+            past_key_value.update_seen_tokens(self.layer_idx, q_len)
+            kv_seq_len = past_key_value.get_seq_length()
+        elif self.layer_idx == 0:
             past_key_value._seen_tokens = kv_seq_len
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
@@ -927,42 +958,48 @@ def minicpm_attention_forward_original_4_39(
                                                      cos, sin, position_ids, "llama")
 
     if past_key_value is not None:
-        # update the number of seen tokens
-        if self.layer_idx == 0:
-            past_key_value._seen_tokens += key_states.shape[-2]
-
-        # reuse k, v, self_attention
-        # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
-        if len(past_key_value.key_cache) <= self.layer_idx:
-            past_key_value.key_cache.append(key_states)
-            past_key_value.value_cache.append(value_states)
+        if use_compresskv:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
         else:
-            cache_k = past_key_value.key_cache[self.layer_idx]
-            cache_v = past_key_value.value_cache[self.layer_idx]
+            # update the number of seen tokens
+            if self.layer_idx == 0:
+                past_key_value._seen_tokens += key_states.shape[-2]
 
-            if not enough_kv_room:
-                # allocate new
-                new_c_k, new_c_v = extend_kv_cache(bsz,
-                                                   self.num_key_value_heads,  # Support GQA
-                                                   self.head_dim,
-                                                   cache_k.size(2),
-                                                   kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                   dtype=cache_k.dtype,
-                                                   device=device)
+            # reuse k, v, self_attention
+            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
+            if len(past_key_value.key_cache) <= self.layer_idx:
+                past_key_value.key_cache.append(key_states)
+                past_key_value.value_cache.append(value_states)
+            else:
+                cache_k = past_key_value.key_cache[self.layer_idx]
+                cache_v = past_key_value.value_cache[self.layer_idx]
 
-                new_c_k[:] = cache_k
-                new_c_v[:] = cache_v
-                cache_k = new_c_k
-                cache_v = new_c_v
+                if not enough_kv_room:
+                    # allocate new
+                    new_c_k, new_c_v = extend_kv_cache(bsz,
+                                                       self.num_key_value_heads,  # Support GQA
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
 
-            key_states, value_states = append_kv_cache(cache_k,
-                                                        cache_v,
-                                                        key_states,
-                                                        value_states)
+                    new_c_k[:] = cache_k
+                    new_c_v[:] = cache_v
+                    cache_k = new_c_k
+                    cache_v = new_c_v
 
-            # update past_key_value
-            past_key_value.key_cache[self.layer_idx] = key_states
-            past_key_value.value_cache[self.layer_idx] = value_states
+                key_states, value_states = append_kv_cache(cache_k,
+                                                           cache_v,
+                                                           key_states,
+                                                           value_states)
+
+                # update past_key_value
+                past_key_value.key_cache[self.layer_idx] = key_states
+                past_key_value.value_cache[self.layer_idx] = value_states
 
     if cache_position is not None:
         new_attention_mask = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
@@ -983,6 +1020,9 @@ def minicpm_attention_forward_original_4_39(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
+        if use_compresskv:
+            # [CompressKV] set attention_mask = None
+            new_attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index bc66b77f..35d7abae 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -210,7 +210,7 @@ def mistral_model_forward_4_36(
                                  self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input_ids):
+        elif should_use_compresskv(input_ids, input_ids.shape[-1]):
             # if use quantize kv, compress kv will be ignored now
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(
@@ -902,13 +902,15 @@ def mistral_attention_forward_4_36_original(
     use_cache: bool=False,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
+
     bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
     # for flash attention
     original_dtype = hidden_states.dtype
 
     # [CompressKV]
-    use_compresskv = should_use_compresskv(hidden_states)
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
@@ -1156,13 +1158,14 @@ def mistral_attention_forward_4_39_original(
     use_cache: bool=False,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
     bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
     # for flash attention
     original_dtype = hidden_states.dtype
 
     # [CompressKV]
-    use_compresskv = should_use_compresskv(hidden_states)
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 4bf7ae1b..5a56add8 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -118,7 +118,7 @@ def qwen2_model_forward(
         and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
                                   self.config.num_attention_heads//self.config.num_key_value_heads)
     )
-    use_compress_kv = should_use_compresskv(inputs)
+    use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
 
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
@@ -401,7 +401,8 @@ def qwen2_attention_forward(
     device = hidden_states.device
 
     # [CompressKV]
-    use_compresskv = should_use_compresskv(hidden_states)
+    from ipex_llm.transformers.kv import DynamicCompressCache
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
         qkv = self.qkv_proj(hidden_states)
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 5ebdeaa6..0b344c69 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -481,6 +481,13 @@ def update_past_key_value(past_key_value, key_states, value_states,
     return key_states, value_states
 
 
-def should_use_compresskv(x: torch.Tensor):
+def should_use_compresskv(x: torch.Tensor, prompt_len: int):
     use_compress_kv = os.environ.get("IPEX_LLM_COMPRESS_KV_CACHE", None)
-    return x.device.type == 'xpu' and use_compress_kv == "1"
+    if use_compress_kv is None:
+        return (
+            get_xpu_device_type(x) == "mtl"
+            and prompt_len >= 2500
+            and prompt_len <= 4500
+        )
+    else:
+        return x.device.type == 'xpu' and use_compress_kv == "1"
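
The sketch below is a minimal, standalone illustration of the new selection logic in the `should_use_compresskv` hunk above; it is not library code. The names `device_type`, `is_xpu`, and `env_value` are stand-ins for `get_xpu_device_type(x)`, `x.device.type == 'xpu'`, and `os.environ.get("IPEX_LLM_COMPRESS_KV_CACHE", None)` in the patched helper.

# Illustrative sketch only -- mirrors the patched should_use_compresskv logic.
from typing import Optional


def should_use_compresskv_sketch(device_type: str, is_xpu: bool,
                                 prompt_len: int, env_value: Optional[str]) -> bool:
    if env_value is None:
        # New default: enable KV compression automatically on "mtl" devices
        # for prompts between 2500 and 4500 tokens.
        return device_type == "mtl" and 2500 <= prompt_len <= 4500
    # Env var set: keep the previous explicit opt-in ("1" on an XPU device).
    return is_xpu and env_value == "1"


# With the variable unset, a 3000-token prompt on an "mtl" device now opts in:
assert should_use_compresskv_sketch("mtl", True, 3000, None) is True
# ...while a short prompt, or a non-"mtl" device, does not:
assert should_use_compresskv_sketch("mtl", True, 1000, None) is False
assert should_use_compresskv_sketch("arc", True, 3000, None) is False
# Setting IPEX_LLM_COMPRESS_KV_CACHE=1 still forces compression on any XPU device:
assert should_use_compresskv_sketch("arc", True, 3000, "1") is True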