Fix compresskv with lookahead issue (#11767)
* fix compresskv + lookahead attn_mask qwen2
* support llama chatglm
* support mistral & chatglm
* address comments
* revert run.py
parent f97a77ea4e
commit 841dbcdf3a
6 changed files with 37 additions and 15 deletions
chatglm2:

@@ -108,7 +108,10 @@ def chatglm2_model_forward(
         if past_key_values is None:
             position_ids = torch.arange(seq_length, dtype=torch.int64, device=inputs_embeds.device)
         else:
-            kv_length = past_key_values[0][0].size(0)
+            if isinstance(past_key_values, DynamicCompressCache):
+                kv_length = past_key_values.get_seq_length()
+            else:
+                kv_length = past_key_values[0][0].size(0)
             position_ids = torch.arange(kv_length, kv_length + seq_length,
                                         dtype=torch.int64, device=inputs_embeds.device)
         position_ids = position_ids.repeat(batch_size, 1)
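Note on the position_ids change above (the chatglm4 hunk below makes the same change): with DynamicCompressCache the cached key tensor is physically shorter than the number of tokens already processed, so deriving kv_length from past_key_values[0][0].size(0) would make the new positions start too low and repeat ids that were already used. A minimal sketch with made-up lengths, assuming get_seq_length() reports the logical (uncompressed) token count:

import torch

# Hypothetical lengths for illustration only.
processed_tokens = 32    # assumed return of a compressed cache's get_seq_length()
compressed_entries = 20  # physical size of the compressed key/value tensors
seq_length = 3           # new tokens in this forward pass

# Taking kv_length from the tensor size would reuse positions 20, 21, 22
# even though positions 20..31 were already consumed before compression.
wrong_positions = torch.arange(compressed_entries, compressed_entries + seq_length)

# Taking it from the logical length keeps position_ids monotonic: 32, 33, 34.
right_positions = torch.arange(processed_tokens, processed_tokens + seq_length)
print(wrong_positions.tolist(), right_positions.tolist())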
@@ -300,6 +303,8 @@ def chatglm2_attention_forward(
     attn_weights = None
     if use_sdp(q_len, kv_seq_len, head_dim, query_states):
         import xe_addons
+        if use_compresskv and attention_mask is not None:
+            attention_mask = None
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
         else:
chatglm4:

@@ -21,7 +21,8 @@ import torch
 from typing import Optional, Tuple, Union
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
-    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
+    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36, \
+    get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.chatglm2 import repeat_kv
 from ipex_llm.transformers.kv import DynamicCompressCache

@@ -79,7 +80,10 @@ def chatglm4_model_forward(
         if past_key_values is None:
             position_ids = torch.arange(seq_length, dtype=torch.int64, device=inputs_embeds.device)
         else:
-            kv_length = past_key_values[0][0].size(2)
+            if isinstance(past_key_values, DynamicCompressCache):
+                kv_length = past_key_values.get_seq_length()
+            else:
+                kv_length = past_key_values[0][0].size(2)
             position_ids = torch.arange(kv_length, kv_length + seq_length,
                                         dtype=torch.int64, device=inputs_embeds.device)
         position_ids = position_ids.repeat(batch_size, 1)

@@ -232,6 +236,8 @@ def chatglm4_attention_forward(
             attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
     elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training):
         import xe_addons
+        if use_compresskv:
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
                                                    attention_mask)
llama:

@@ -42,7 +42,8 @@ import torch.nn.functional as F
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import SILU
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
-    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
+    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv, \
+    get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu

@@ -1547,9 +1548,10 @@ def llama_attention_forward_4_41_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
+        # [CompressKV]
         if use_compresskv:
-            # [CompressKV] set attention_mask = None
-            new_attention_mask = None
+            new_attention_mask = get_compresskv_attn_mask(key_states,
+                                                          new_attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
         attn_output = attn_output.view(query_states.shape)

@@ -2111,9 +2113,10 @@ def llama_attention_forward_4_38_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
+        # [CompressKV]
         if use_compresskv:
-            # [CompressKV] set attention_mask = None
-            new_attention_mask = None
+            new_attention_mask = get_compresskv_attn_mask(key_states,
+                                                          new_attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
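The attention-mask part of the two llama hunks above (and the matching mistral/qwen2 hunks below) targets the lookahead case. Setting the mask to None under CompressKV is only harmless when one token is decoded per step; with lookahead decoding several tokens are scored in a single step (q_len > 1), so the causal structure between those tokens must be kept, and after compression the mask's key dimension has to match the shortened cache, hence get_compresskv_attn_mask. A self-contained sketch with made-up shapes, not the library's real tensors:

import torch

# Hedged sketch: with plain token-by-token decoding q_len == 1, so every cached
# key may be attended and an all-zero additive mask is equivalent to passing
# None. With lookahead decoding q_len > 1 and the mask must keep the causal
# -inf entries among the newest positions.
def toy_decode_mask(q_len: int, kv_len: int) -> torch.Tensor:
    mask = torch.zeros(1, 1, q_len, kv_len)
    # causal structure only matters among the q_len newest positions
    mask[..., -q_len:] = torch.triu(
        torch.full((q_len, q_len), float("-inf")), diagonal=1)
    return mask

assert torch.count_nonzero(toy_decode_mask(1, 20)) == 0   # passing None is fine
assert torch.count_nonzero(toy_decode_mask(4, 20)) > 0    # a mask is required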
mistral:

@@ -46,7 +46,8 @@ from transformers.models.mistral.modeling_mistral import MistralModel
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
-    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
+    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv, \
+    get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
     apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \

@@ -1097,9 +1098,9 @@ def mistral_attention_forward_4_36_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [CompressKV] set attention_mask = None
+        # [CompressKV]
         if use_compresskv:
-            attention_mask = None
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None

@@ -1348,9 +1349,9 @@ def mistral_attention_forward_4_39_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [CompressKV] set attention_mask = None
+        # [CompressKV]
         if use_compresskv:
-            attention_mask = None
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
qwen2:

@@ -48,7 +48,7 @@ from torch.nn.functional import scaled_dot_product_attention as sdpa
 from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
-    should_use_compresskv, is_enough_kv_cache_room_4_36
+    should_use_compresskv, is_enough_kv_cache_room_4_36, get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, DynamicCompressCache
 from ipex_llm.utils.common import invalidInputError

@@ -473,7 +473,7 @@ def qwen2_attention_forward(
     elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import xe_addons
         if use_compresskv:
-            attention_mask = None
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         if isinstance(past_key_value, DynamicFp8Cache):
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
utils:

@@ -497,6 +497,13 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
         return x.device.type == 'xpu' and use_compress_kv == "1"
 
 
+def get_compresskv_attn_mask(key_states: torch.Tensor,
+                             attention_mask: torch.Tensor):
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, -key_states.size(2):]
+    return attention_mask
+
+
 def get_q_proj_or_qkv_proj(self):
     if hasattr(self, "q_proj"):
         proj = self.q_proj
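For reference, a standalone sketch (illustrative shapes only) of what the new helper does: the 4-D additive mask built for the full, uncompressed sequence is sliced along its last dimension so it lines up with however many entries the compressed cache actually holds.

import torch

# Illustrative shapes; the real tensors come from the model forward.
batch, heads, q_len, head_dim = 1, 2, 4, 64
full_kv_len, compressed_kv_len = 32, 20

key_states = torch.randn(batch, heads, compressed_kv_len, head_dim)
attention_mask = torch.zeros(batch, 1, q_len, full_kv_len)  # additive mask

# Same slicing as get_compresskv_attn_mask: keep only the trailing
# key_states.size(2) columns so mask and keys agree on the KV length.
sliced = attention_mask[:, :, :, -key_states.size(2):]
assert sliced.shape[-1] == key_states.size(2)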