diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py
index 84416832..f8d7d186 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py
@@ -271,14 +271,9 @@ def baichuan_attention_forward_7b(
     # IPEX-LLM OPT: kv cache and quantize kv
     use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
 
-    if use_quantize_kv or (not use_compresskv):
-        key_states, value_states = update_past_key_value(
-            past_key_value, key_states, value_states,
-            kv_seq_len, use_quantize_kv, device
-        )
-        past_key_value = (key_states, value_states) if use_cache else None
-    else:
+    # [CompressKV]
+    if use_compresskv:
         enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
                                                       self.layer_idx, q_len)
@@ -286,6 +281,12 @@ def baichuan_attention_forward_7b(
             key_states, value_states, self.layer_idx,
             query_states, attention_mask, 1,
             self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+    else:
+        key_states, value_states = update_past_key_value(
+            past_key_value, key_states, value_states,
+            kv_seq_len, use_quantize_kv, device
+        )
+        past_key_value = (key_states, value_states) if use_cache else None
 
     if self.training:
         warnings.warn("xops is not supported on Intel GPU, so just use normal implementation")
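
Note (not part of the patch): the sketch below is a minimal, stand-alone approximation of the control flow this change produces, i.e. the CompressKV path is checked first and the regular update_past_key_value path becomes the fallback. The helper it defines and the cache object's .update() signature are hypothetical stand-ins, not the ipex_llm APIs, and the quantize-kv handling inside update_past_key_value is reduced here to a plain concatenation.

import torch


def update_kv_cache(past_key_value, key_states, value_states,
                    layer_idx, use_compresskv, use_cache):
    # [CompressKV] branch is taken first, mirroring the patched code: the
    # compress-kv cache object decides what to keep and returns the
    # (possibly compressed) key/value tensors used for attention.
    if use_compresskv:
        key_states, value_states = past_key_value.update(
            key_states, value_states, layer_idx)  # hypothetical signature
        return key_states, value_states, past_key_value

    # Fallback: the regular path (update_past_key_value in ipex_llm), here
    # reduced to concatenating along the sequence dimension (dim=2) of
    # [batch, num_heads, seq_len, head_dim] tensors.
    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    new_past = (key_states, value_states) if use_cache else None
    return key_states, value_states, new_past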