Qwen support compress kv (#11680)

* Qwen support compress kv
* fix style
* fix
Parent: 9b36877897
Commit: 670ad887fc

4 changed files with 56 additions and 18 deletions
ipex_llm/transformers/kv.py (class DynamicCompressCache):

@@ -259,7 +259,28 @@ class DynamicCompressCache(DynamicCache):
                 num_key_value_groups=num_key_value_groups)
             self.key_cache.append(key_states_compress)
             self.value_cache.append(value_states_compress)
-            return key_states, value_states
+
+            k_cache_compressed, v_cache_compressed = init_kv_cache(
+                bsz, num_heads, head_dim,
+                0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                key_states.dtype, key_states.device
+            )
+            k_cache_compressed, v_cache_compressed = append_kv_cache(
+                k_cache_compressed, v_cache_compressed,
+                key_states_compress, value_states_compress)
+            self.key_cache[layer_idx] = k_cache_compressed
+            self.value_cache[layer_idx] = v_cache_compressed
+
+            if key_states.stride(2) != head_dim:
+                k_cache, v_cache = init_kv_cache(
+                    bsz, num_heads, head_dim,
+                    0, key_states.size(2),
+                    key_states.dtype, key_states.device
+                )
+                k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
+                return k_cache, v_cache
+            else:
+                return key_states, value_states
         else:
             cache_k = self.key_cache[layer_idx]
             cache_v = self.value_cache[layer_idx]
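Context for the hunk above: after compression selects the prefill entries to keep, the cache no longer just returns the raw key/value states. It pre-allocates a buffer padded by KV_CACHE_ALLOC_BLOCK_LENGTH, copies the compressed tensors into it, and stores those buffers in key_cache/value_cache so later steps can append in place; the raw states are re-packed into a fresh buffer only when their sequence stride shows they are a non-compact view (the stride(2) != head_dim check). Below is a minimal, self-contained sketch of that pre-allocate-then-copy pattern in plain PyTorch; the class name, shapes, and the 256-token head-room are illustrative stand-ins, not the actual ipex_llm init_kv_cache/append_kv_cache implementations.

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # illustrative head-room, mirroring the padding role of the constant above

class PreallocatedKVCache:
    """Toy stand-in for the init_kv_cache/append_kv_cache pattern: allocate once
    with spare capacity, then copy new K/V into the same buffer on later steps."""

    def __init__(self, bsz, num_heads, head_dim, max_len, dtype, device):
        self.k = torch.empty(bsz, num_heads, max_len, head_dim, dtype=dtype, device=device)
        self.v = torch.empty_like(self.k)
        self.filled = 0

    def append(self, new_k, new_v):
        n = new_k.size(2)
        assert self.filled + n <= self.k.size(2), "out of room; the real helpers re-allocate"
        self.k[:, :, self.filled:self.filled + n, :] = new_k
        self.v[:, :, self.filled:self.filled + n, :] = new_v
        self.filled += n
        # views over the filled region, analogous to what append_kv_cache returns
        return self.k[:, :, :self.filled, :], self.v[:, :, :self.filled, :]

# Usage: seed the cache with the compressed prefill K/V, as the hunk above does.
bsz, num_heads, head_dim = 1, 32, 128
k_compress = torch.randn(bsz, num_heads, 64, head_dim)
v_compress = torch.randn(bsz, num_heads, 64, head_dim)
cache = PreallocatedKVCache(bsz, num_heads, head_dim,
                            k_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
                            k_compress.dtype, k_compress.device)
key_states, value_states = cache.append(k_compress, v_compress)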
ipex_llm/transformers/models/llama.py:

@@ -1289,7 +1289,7 @@ def llama_attention_forward_4_41_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
-    # [SnapKV]
+    # [CompressKV]
     use_compresskv = should_use_compresskv(hidden_states)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)

@@ -1324,7 +1324,7 @@ def llama_attention_forward_4_41_original(
                                                                        self.rotary_emb.base,)
         kv_seq_len += 1
         # update past_key_value's seem_tokens and kv caches.
-        # [SnapKV]
+        # [CompressKV]
         if use_compresskv:
             past_key_value.update_seen_tokens(self.layer_idx, q_len)
             kv_seq_len = past_key_value.get_seq_length()

@@ -1496,7 +1496,7 @@ def llama_attention_forward_4_41_original(
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
         if use_compresskv:
-            # [SnapKV] set attention_mask = None
+            # [CompressKV] set attention_mask = None
             new_attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)

@@ -1833,7 +1833,7 @@ def llama_attention_forward_4_38_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
-    # [SnapKV]
+    # [CompressKV]
     use_compresskv = should_use_compresskv(hidden_states)
 
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)

@@ -1868,7 +1868,7 @@ def llama_attention_forward_4_38_original(
                                                                        self.rotary_emb.base,)
         kv_seq_len += 1
         # update past_key_value's seem_tokens and kv caches.
-        # [SnapKV]
+        # [CompressKV]
         if use_compresskv:
             past_key_value.update_seen_tokens(self.layer_idx, q_len)
             kv_seq_len = past_key_value.get_seq_length()

@@ -2039,7 +2039,7 @@ def llama_attention_forward_4_38_original(
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
         if use_compresskv:
-            # [SnapKV] set attention_mask = None
+            # [CompressKV] set attention_mask = None
             new_attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
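The llama hunks are comment-only renames ([SnapKV] -> [CompressKV]); the logic they annotate is unchanged. For readers following the decode branch: when the compressed cache is active, the attention code first registers the newly arrived tokens with update_seen_tokens and then re-reads the effective length from get_seq_length, because the cached length no longer equals the number of tokens the model has processed. A minimal runnable stand-in for that call order (the class below is a toy, not DynamicCompressCache from ipex_llm.transformers.kv):

class ToyCompressCache:
    """Toy cache that only mimics the two calls used in the decode branch above."""

    def __init__(self, compressed_prefill_len):
        # Entries actually held after prefill compression (fewer than tokens seen).
        self.kv_len = compressed_prefill_len

    def update_seen_tokens(self, layer_idx, q_len):
        # Decode tokens are appended uncompressed, so the usable length grows by q_len.
        self.kv_len += q_len

    def get_seq_length(self, layer_idx=0):
        return self.kv_len

past_key_value = ToyCompressCache(compressed_prefill_len=512)
q_len = 1                                        # one new token per decode step
past_key_value.update_seen_tokens(layer_idx=0, q_len=q_len)
kv_seq_len = past_key_value.get_seq_length()     # 513: the length attention actually sees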
ipex_llm/transformers/models/mistral.py:

@@ -897,7 +897,7 @@ def mistral_attention_forward_4_36_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
-    # [SnapKV]
+    # [CompressKV]
     use_compresskv = should_use_compresskv(hidden_states)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)

@@ -930,7 +930,7 @@ def mistral_attention_forward_4_36_original(
         kv_seq_len += 1
 
         # update past_key_value's seem_tokens and kv caches.
-        # [SnapKV]
+        # [CompressKV]
         if use_compresskv:
             past_key_value.update_seen_tokens(self.layer_idx, q_len)
             kv_seq_len = past_key_value.get_seq_length()

@@ -1055,7 +1055,7 @@ def mistral_attention_forward_4_36_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [SnapKV] set attention_mask = None
+        # [CompressKV] set attention_mask = None
         if use_compresskv:
             attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)

@@ -1142,7 +1142,7 @@ def mistral_attention_forward_4_39_original(
     # for flash attention
     original_dtype = hidden_states.dtype
 
-    # [SnapKV]
+    # [CompressKV]
     use_compresskv = should_use_compresskv(hidden_states)
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)

@@ -1175,7 +1175,7 @@ def mistral_attention_forward_4_39_original(
         kv_seq_len += 1
 
         # update past_key_value's seem_tokens and kv caches.
-        # [SnapKV]
+        # [CompressKV]
         if use_compresskv:
             past_key_value.update_seen_tokens(self.layer_idx, q_len)
             kv_seq_len = past_key_value.get_seq_length()

@@ -1300,7 +1300,7 @@ def mistral_attention_forward_4_39_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [SnapKV] set attention_mask = None
+        # [CompressKV] set attention_mask = None
         if use_compresskv:
             attention_mask = None
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
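The mistral hunks are the same comment rename, here around the fused SDP call where the attention mask is dropped on the compressed path, presumably because the mask prepared earlier no longer matches the compressed cache length, and in single-token decode every cached entry is attendable anyway. As a rough illustration of the resulting call shape, the sketch below uses PyTorch's scaled_dot_product_attention in place of xe_addons.sdp (the XPU kernel used in the diff); head counts are kept equal so the example runs without grouped-query handling.

import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa

# Decode-time shapes: one query token attending over the compressed-plus-decoded cache.
bsz, num_heads, head_dim, kv_len = 1, 32, 128, 513
query_states = torch.randn(bsz, num_heads, 1, head_dim)
key_states = torch.randn(bsz, num_heads, kv_len, head_dim)
value_states = torch.randn(bsz, num_heads, kv_len, head_dim)

attention_mask = None  # [CompressKV] path: no mask, all cached positions are visible
attn_output = sdpa(query_states, key_states, value_states, attn_mask=attention_mask)
assert attn_output.shape == (bsz, num_heads, 1, head_dim)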
ipex_llm/transformers/models/qwen2.py:

@@ -47,9 +47,10 @@ from torch.nn.functional import scaled_dot_product_attention as sdpa
 
 from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
+    should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
-from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
+from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, DynamicCompressCache
 from ipex_llm.utils.common import invalidInputError
 
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP

@@ -117,11 +118,16 @@ def qwen2_model_forward(
         and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
                                   self.config.num_attention_heads//self.config.num_key_value_heads)
     )
+    use_compress_kv = should_use_compresskv(inputs)
 
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+        elif use_compress_kv and not isinstance(past_key_values,
+                                                DynamicCompressCache):
+            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                          DynamicNormalCache):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
         past_key_values_length = past_key_values.get_usable_length(seq_length)
     # ipex-llm changes end
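The model-forward hunk above decides once per forward pass which cache implementation wraps past_key_values: the FP8-quantized cache takes precedence, then the compressed cache, and the plain dynamic cache is the fallback when neither feature is active. A condensed, self-contained sketch of that selection logic (the stub classes below only stand in for the real ones in ipex_llm.transformers.kv so the example runs on its own):

class _StubCache:
    @classmethod
    def from_legacy_cache(cls, past_key_values):
        return cls()

# Stand-ins for the ipex_llm.transformers.kv classes named in the diff.
class DynamicFp8Cache(_StubCache): pass
class DynamicCompressCache(_StubCache): pass
class DynamicNormalCache(_StubCache): pass

def select_kv_cache(past_key_values, use_quantize_kv, use_compress_kv):
    # Same precedence as the hunk: quantized first, then compressed, plain as fallback.
    if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
        past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
    elif use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
        past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
    if (not use_quantize_kv and not use_compress_kv
            and not isinstance(past_key_values, DynamicNormalCache)):
        past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
    return past_key_values

assert isinstance(select_kv_cache(None, False, True), DynamicCompressCache)
assert isinstance(select_kv_cache(None, False, False), DynamicNormalCache)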
@@ -394,6 +400,9 @@ def qwen2_attention_forward(
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
 
+    # [CompressKV]
+    use_compresskv = should_use_compresskv(hidden_states)
+
     if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
         qkv = self.qkv_proj(hidden_states)
         qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)

@@ -427,8 +436,16 @@ def qwen2_attention_forward(
                                                         cos, sin, position_ids)
 
     if past_key_value is not None:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                          self.layer_idx, None)
+        # [CompressKV]
+        if use_compresskv:
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, 256)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, None)
 
     attn_weights = None
     if query_states.device.type == "cpu":
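The attention-forward hunk is where compression is actually triggered for Qwen2: on the [CompressKV] path, past_key_value.update is given the query states, attention mask, GQA group count, model config, an enough_kv_room flag from is_enough_kv_cache_room_4_36, and a literal 256 whose role is not visible in this diff (plausibly an allocation block length, given KV_CACHE_ALLOC_BLOCK_LENGTH in the kv.py hunk), so the compressed cache can decide which prefill entries to keep; the plain path keeps the usual transformers-style update call. A caller-side sketch of that dispatch, with argument meanings inferred from the names above and the update signature taken as-is from the diff:

def update_kv(past_key_value, key_states, value_states, layer_idx, use_compresskv,
              query_states=None, attention_mask=None, num_key_value_groups=1,
              config=None, enough_kv_room=True, alloc_block_len=256):
    """Sketch of the dispatch added above; `past_key_value` is assumed to be a
    DynamicCompressCache on the compressed path and a regular cache otherwise."""
    if use_compresskv:
        # Compressed path: the cache needs queries, mask, GQA groups and config
        # to score and keep only the most relevant prefill entries.
        return past_key_value.update(
            key_states, value_states, layer_idx,
            query_states, attention_mask, num_key_value_groups,
            config, enough_kv_room, alloc_block_len)
    # Plain path: standard transformers-style Cache.update(key, value, layer_idx, cache_kwargs).
    return past_key_value.update(key_states, value_states, layer_idx, None)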