From a2be3d75016a40b4a71104ed9811368a0ec70fa0 Mon Sep 17 00:00:00 2001 From: "Huang, Xinshengzi" Date: Thu, 22 Aug 2024 15:11:55 +0800 Subject: [PATCH] Hoist compressed-KV attention mask handling before SDP dispatch in attention forward --- python/llm/src/ipex_llm/transformers/models/baichuan.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index af34b3a5..3944a948 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -307,13 +307,11 @@ def baichuan_attention_forward_7b( is_causal=True).to(hidden_states.dtype) elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states): import xe_addons + if use_compresskv: + attention_mask = get_compresskv_attn_mask(key_states, attention_mask) if use_quantize_kv: attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask) - elif use_compresskv: - attention_mask = get_compresskv_attn_mask(key_states, attention_mask) - attn_output = xe_addons.sdp(query_states, key_states, value_states, - attention_mask) else: attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)