Fix compresskv with lookahead issue (#11767)
* fix compresskv + lookahead attn_mask qwen2
* support llama chatglm
* support mistral & chatglm
* address comments
* revert run.py
parent f97a77ea4e
commit 841dbcdf3a
6 changed files with 37 additions and 15 deletions
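For context: with lookahead, each decoding step carries several query tokens, so the 4D attention mask is built with its last dimension equal to the full KV history, while the compressed KV cache keeps only a subset of entries, so key_states.size(2) is smaller. The short sketch below (plain PyTorch, invented sizes, not ipex-llm code) shows the shape bookkeeping that the new get_compresskv_attn_mask helper in this commit performs:

    import torch

    bsz, q_len = 1, 3                    # lookahead: more than one query token per step
    full_kv_len, kept_kv_len = 20, 8     # mask length vs. entries kept by the compressed cache

    attention_mask = torch.zeros(bsz, 1, q_len, full_kv_len)   # additive 4D mask, 0 = visible
    key_states = torch.randn(bsz, 2, kept_kv_len, 16)          # (batch, heads, kv_len, head_dim)

    # The mask no longer lines up with the compressed keys ...
    assert attention_mask.size(-1) != key_states.size(2)

    # ... so the patch keeps only the last key_states.size(2) columns of the mask
    # (which is what get_compresskv_attn_mask does) instead of dropping the mask.
    attention_mask = attention_mask[:, :, :, -key_states.size(2):]
    assert attention_mask.size(-1) == key_states.size(2)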
ipex_llm/transformers/models/chatglm2.py

@@ -108,7 +108,10 @@ def chatglm2_model_forward(
         if past_key_values is None:
             position_ids = torch.arange(seq_length, dtype=torch.int64, device=inputs_embeds.device)
         else:
-            kv_length = past_key_values[0][0].size(0)
+            if isinstance(past_key_values, DynamicCompressCache):
+                kv_length = past_key_values.get_seq_length()
+            else:
+                kv_length = past_key_values[0][0].size(0)
             position_ids = torch.arange(kv_length, kv_length + seq_length,
                                         dtype=torch.int64, device=inputs_embeds.device)
         position_ids = position_ids.repeat(batch_size, 1)
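A note on the position_ids branch above: a compressed cache stores fewer KV entries than tokens already processed, so past_key_values[0][0].size(0) under-counts the history and the new tokens would get wrong positions; DynamicCompressCache.get_seq_length() tracks the logical length instead. A minimal sketch with a hypothetical stand-in cache (FakeCompressCache and all sizes are invented for illustration):

    import torch

    class FakeCompressCache:
        # Stand-in for DynamicCompressCache: keeps fewer KV entries than tokens seen so far.
        def __init__(self, seen_tokens, kept_entries, num_heads=2, head_dim=16):
            self._seen_tokens = seen_tokens
            self.key_cache = [torch.randn(1, num_heads, kept_entries, head_dim)]

        def get_seq_length(self):
            return self._seen_tokens          # logical length, not the stored tensor size

    cache = FakeCompressCache(seen_tokens=20, kept_entries=8)
    seq_length = 3                            # a lookahead step appending 3 new tokens

    stored_len = cache.key_cache[0].size(2)   # 8  -> positions would restart too early
    kv_length = cache.get_seq_length()        # 20 -> positions continue where they should

    position_ids = torch.arange(kv_length, kv_length + seq_length, dtype=torch.int64)
    print(position_ids)                       # tensor([20, 21, 22])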
@@ -300,6 +303,8 @@ def chatglm2_attention_forward(
     attn_weights = None
     if use_sdp(q_len, kv_seq_len, head_dim, query_states):
         import xe_addons
+        if use_compresskv and attention_mask is not None:
+            attention_mask = None
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
         else:
ipex_llm/transformers/models/chatglm4.py

@@ -21,7 +21,8 @@ import torch
 from typing import Optional, Tuple, Union
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
-    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
+    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36, \
+    get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.chatglm2 import repeat_kv
 from ipex_llm.transformers.kv import DynamicCompressCache
@@ -79,7 +80,10 @@ def chatglm4_model_forward(
         if past_key_values is None:
             position_ids = torch.arange(seq_length, dtype=torch.int64, device=inputs_embeds.device)
         else:
-            kv_length = past_key_values[0][0].size(2)
+            if isinstance(past_key_values, DynamicCompressCache):
+                kv_length = past_key_values.get_seq_length()
+            else:
+                kv_length = past_key_values[0][0].size(2)
             position_ids = torch.arange(kv_length, kv_length + seq_length,
                                         dtype=torch.int64, device=inputs_embeds.device)
         position_ids = position_ids.repeat(batch_size, 1)
@@ -232,6 +236,8 @@ def chatglm4_attention_forward(
             attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
     elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training):
         import xe_addons
+        if use_compresskv:
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
                                                    attention_mask)
ipex_llm/transformers/models/llama.py

@@ -42,7 +42,8 @@ import torch.nn.functional as F
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import SILU
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
-    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
+    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv, \
+    get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -1547,9 +1548,10 @@ def llama_attention_forward_4_41_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
+        # [CompressKV]
         if use_compresskv:
-            # [CompressKV] set attention_mask = None
-            new_attention_mask = None
+            new_attention_mask = get_compresskv_attn_mask(key_states,
+                                                          new_attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
@@ -2111,9 +2113,10 @@ def llama_attention_forward_4_38_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
+        # [CompressKV]
         if use_compresskv:
-            # [CompressKV] set attention_mask = None
-            new_attention_mask = None
+            new_attention_mask = get_compresskv_attn_mask(key_states,
+                                                          new_attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states,
                                     new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
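These llama branches previously dropped the mask entirely under compresskv; with a single query token that is mostly harmless, but with lookahead (q_len > 1) the mask also encodes which of the speculative tokens may see which cached keys, so slicing it is presumably the safer behaviour. A standalone sketch, using PyTorch's scaled_dot_product_attention as a stand-in for the fused xe_addons.sdp kernel (sizes and mask pattern are invented):

    import torch
    import torch.nn.functional as F

    bsz, n_heads, head_dim = 1, 2, 16
    q_len, full_kv_len, kept_kv_len = 3, 20, 8

    query = torch.randn(bsz, n_heads, q_len, head_dim)
    key = torch.randn(bsz, n_heads, kept_kv_len, head_dim)
    value = torch.randn(bsz, n_heads, kept_kv_len, head_dim)

    # Additive mask built for the full history: 0 = visible, -inf = hidden.
    full_mask = torch.zeros(bsz, 1, q_len, full_kv_len)
    full_mask[..., 0, -2:] = float("-inf")    # first lookahead token must not see the last two keys
    full_mask[..., 1, -1:] = float("-inf")    # second lookahead token must not see the last key

    # Old behaviour: no mask at all -- every lookahead token attends to every cached key.
    unmasked = F.scaled_dot_product_attention(query, key, value)

    # New behaviour: slice the mask to the kept keys (what get_compresskv_attn_mask returns).
    sliced_mask = full_mask[:, :, :, -key.size(2):]
    masked = F.scaled_dot_product_attention(query, key, value, attn_mask=sliced_mask)

    print(sliced_mask.shape)                  # torch.Size([1, 1, 3, 8])
    print(torch.allclose(unmasked, masked))   # False here: the mask changes the output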
ipex_llm/transformers/models/mistral.py

@@ -46,7 +46,8 @@ from transformers.models.mistral.modeling_mistral import MistralModel
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
-    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv
+    restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv, \
+    get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
     apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
@@ -1097,9 +1098,9 @@ def mistral_attention_forward_4_36_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [CompressKV] set attention_mask = None
+        # [CompressKV]
         if use_compresskv:
-            attention_mask = None
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
@@ -1348,9 +1349,9 @@ def mistral_attention_forward_4_39_original(
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
-        # [CompressKV] set attention_mask = None
+        # [CompressKV]
         if use_compresskv:
-            attention_mask = None
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
ipex_llm/transformers/models/qwen2.py

@@ -48,7 +48,7 @@ from torch.nn.functional import scaled_dot_product_attention as sdpa
 from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
-    should_use_compresskv, is_enough_kv_cache_room_4_36
+    should_use_compresskv, is_enough_kv_cache_room_4_36, get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, DynamicCompressCache
 from ipex_llm.utils.common import invalidInputError
@@ -473,7 +473,7 @@ def qwen2_attention_forward(
     elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import xe_addons
         if use_compresskv:
-            attention_mask = None
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         if isinstance(past_key_value, DynamicFp8Cache):
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
ipex_llm/transformers/models/utils.py

@@ -497,6 +497,13 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
     return x.device.type == 'xpu' and use_compress_kv == "1"
 
 
+def get_compresskv_attn_mask(key_states: torch.Tensor,
+                             attention_mask: torch.Tensor):
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, -key_states.size(2):]
+    return attention_mask
+
+
 def get_q_proj_or_qkv_proj(self):
     if hasattr(self, "q_proj"):
         proj = self.q_proj
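The helper is used the same way in every model file touched above; a minimal usage example with toy shapes, assuming an ipex-llm build that includes this patch (otherwise the five-line function above can be pasted in directly):

    import torch
    from ipex_llm.transformers.models.utils import get_compresskv_attn_mask

    key_states = torch.randn(1, 2, 8, 16)     # (batch, heads, kept_kv_len, head_dim)

    # None passes through unchanged, e.g. for paths that never built a mask.
    assert get_compresskv_attn_mask(key_states, None) is None

    # A 4D mask built for the full 20-token history is trimmed to the last 8 key positions.
    full_mask = torch.zeros(1, 1, 3, 20)
    print(get_compresskv_attn_mask(key_states, full_mask).shape)   # torch.Size([1, 1, 3, 8])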