Support compress kv with lookahead (#11752)
* support compress kv with lookahead
* fix missing q_len param in the enough_kv_room check
parent 93455aac09
commit 4b9c57cc60

7 changed files with 32 additions and 12 deletions
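With lookahead (and speculative) decoding, a single decode step can carry several query tokens at once, so the KV-cache room check has to know the current query length instead of implicitly assuming one new token; that is what the q_len / seq_len argument added in the hunks below carries. A minimal sketch of the idea with a hypothetical helper and made-up numbers (not IPEX-LLM's actual is_enough_kv_cache_room_4_36):

def has_enough_kv_room(allocated_len, current_len, seq_len=1):
    # Hypothetical stand-in: room is only "enough" if all seq_len new tokens still fit.
    return allocated_len - current_len >= seq_len

print(has_enough_kv_room(256, 254))             # True  -- assumes a single new token
print(has_enough_kv_room(256, 254, seq_len=4))  # False -- a 4-token lookahead step would overflow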
				
			
@@ -287,7 +287,9 @@ def chatglm2_attention_forward(
     else:
         from transformers.configuration_utils import PretrainedConfig
         self.config = self.config if hasattr(self, "config") else PretrainedConfig()
-        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
+        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                       self.layer_number - 1,
+                                                       q_len)
         key_states, value_states = past_key_value.update(
             key_states, value_states, self.layer_number - 1,
             query_states, attention_mask, n_head // n_kv_head,
@@ -213,7 +213,9 @@ def chatglm4_attention_forward(
     else:
         from transformers.configuration_utils import PretrainedConfig
         self.config = self.config if hasattr(self, "config") else PretrainedConfig()
-        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
+        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                       self.layer_number - 1,
+                                                       q_len)
         key_states, value_states = past_key_value.update(
             key_states, value_states, self.layer_number - 1,
             query_states, attention_mask, n_head // n_kv_head,
@@ -127,7 +127,8 @@ def minicpm_attention_forward_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                   seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
@@ -408,7 +409,8 @@ def minicpm_attention_forward_quantized(
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                   seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
@@ -821,7 +823,8 @@ def minicpm_attention_forward_original_4_39(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                   seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
@@ -699,7 +699,8 @@ def mistral_attention_forward_4_36_quantized(
     original_dtype = hidden_states.dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                   seq_len=q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
                                                 enough_kv_room,
@@ -916,7 +917,9 @@ def mistral_attention_forward_4_36_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                   self.layer_idx,
+                                                   q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
                                                 enough_kv_room,
@@ -1172,7 +1175,8 @@ def mistral_attention_forward_4_39_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                   q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
                                                 enough_kv_room,
@@ -135,7 +135,9 @@ def attention_forward(
     if past_key_value is not None:
         # [CompressKV]
         if use_compresskv:
-            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                           self.layer_idx,
+                                                           q_len)
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_idx,
                 query_states, attention_mask, self.num_key_value_groups,
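For readers unfamiliar with the compress-kv path touched above: the cache's update() call receives the query states and attention mask as well, because the cache decides which key/value positions are worth keeping. The toy function below illustrates only the general idea of scoring and pruning cached positions; it is my own sketch, not DynamicCompressCache's actual algorithm or signature.

import torch

def toy_prune_kv(key_states, value_states, query_states, keep=64):
    # Toy compress-kv idea: score cached positions by how strongly the current
    # queries attend to them (averaged over heads) and keep the top `keep` positions.
    bsz, n_kv_heads, kv_len, head_dim = key_states.shape
    if kv_len <= keep:
        return key_states, value_states
    scores = torch.einsum("bqd,bkd->bqk",
                          query_states.mean(dim=1),
                          key_states.mean(dim=1)).mean(dim=1)       # (bsz, kv_len)
    idx = scores.topk(keep, dim=-1).indices.sort(dim=-1).values     # keep original order
    idx = idx[:, None, :, None].expand(bsz, n_kv_heads, keep, head_dim)
    return key_states.gather(2, idx), value_states.gather(2, idx)

# Example shapes: 8 KV heads, 100 cached positions, 64-dim heads, 4 query tokens.
k, v, q = torch.randn(1, 8, 100, 64), torch.randn(1, 8, 100, 64), torch.randn(1, 32, 4, 64)
k2, v2 = toy_prune_kv(k, v, q)
print(k2.shape)   # torch.Size([1, 8, 64, 64])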
@@ -440,7 +440,8 @@ def qwen2_attention_forward(
     if past_key_value is not None:
         # [CompressKV]
         if use_compresskv:
-            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                           q_len)
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_idx,
                 query_states, attention_mask, self.num_key_value_groups,
@@ -471,6 +472,8 @@ def qwen2_attention_forward(
                            is_causal=True).to(hidden_states.dtype)
     elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import xe_addons
+        if use_compresskv:
+            attention_mask = None
         if isinstance(past_key_value, DynamicFp8Cache):
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
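The qwen2 hunk above also drops the attention mask before the fused SDP call when a compress-kv cache is in use. A plausible reading (my assumption, not stated in the patch): the mask is built for the full sequence length, while the compressed cache keeps fewer KV positions, so the two no longer line up. A small shape check with made-up sizes:

import torch

q_len, full_kv_len, kept_kv_len = 4, 128, 48
attention_mask = torch.zeros(1, 1, q_len, full_kv_len)   # mask sized for the uncompressed context
key_states = torch.randn(1, 8, kept_kv_len, 64)          # compressed cache keeps fewer positions

# The mask's last dim no longer matches the cached KV length, so it cannot be applied
# as-is; with use_compresskv the patched code passes attention_mask = None instead.
print(attention_mask.shape[-1], key_states.shape[-2])    # 128 48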
@@ -460,12 +460,16 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l
 
 def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False):
     if version.parse(trans_version) >= version.parse("4.36.0"):
-        from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
-        if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)):
+        from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache,\
+            DynamicCompressCache
+        if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache,
+                                        DynamicCompressCache)):
             if hasattr(past_key_values, "_seen_tokens"):
                 past_key_values._seen_tokens -= new_cache_size
             else:
                 past_key_values.seen_tokens -= new_cache_size
+            if isinstance(past_key_values, DynamicCompressCache):
+                past_key_values.real_kv_len -= new_cache_size
 
             for i, k in enumerate(past_key_values.key_cache):
                 past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :]
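The last hunk extends speculative decoding's cache cropping to compress-kv caches: besides rolling back the seen-token counter, the real_kv_len bookkeeping of DynamicCompressCache must shrink as well when drafted tokens are rejected. A self-contained sketch of that bookkeeping with a toy cache class (hypothetical, mirroring only the logic visible in the hunk):

import torch

class TinyCompressCache:
    # Toy stand-in for DynamicCompressCache's bookkeeping, not the real class.
    def __init__(self):
        self.key_cache, self.value_cache = [], []
        self._seen_tokens = 0     # tokens the model has processed
        self.real_kv_len = 0      # KV positions actually kept after compression

def crop(cache, new_cache_size):
    # Mirror of the patched _crop_past_key_values branch for compress-kv:
    # roll back the counters, then drop the last new_cache_size positions per layer.
    cache._seen_tokens -= new_cache_size
    cache.real_kv_len -= new_cache_size
    for i, k in enumerate(cache.key_cache):
        cache.key_cache[i] = k[:, :, :-new_cache_size, :]
    for i, v in enumerate(cache.value_cache):
        cache.value_cache[i] = v[:, :, :-new_cache_size, :]

cache = TinyCompressCache()
cache.key_cache = [torch.randn(1, 8, 20, 64)]
cache.value_cache = [torch.randn(1, 8, 20, 64)]
cache._seen_tokens, cache.real_kv_len = 20, 20
crop(cache, 3)                    # e.g. 3 drafted tokens were rejected
print(cache.key_cache[0].shape)   # torch.Size([1, 8, 17, 64])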