From f919f5792a5a48632a87c019325c5b3079938f4d Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Fri, 5 Jan 2024 12:38:57 +0800
Subject: [PATCH] fix kv cache out of bound (#9827)

---
 python/llm/src/bigdl/llm/transformers/models/llama.py | 2 +-
 python/llm/src/bigdl/llm/transformers/models/utils.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index d4e9e223..f01d0cd3 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -149,7 +149,7 @@ def llama_attention_forward_4_31(
         attention_dtype = original_dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
+    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
     is_q4_0 = qtype == SYM_INT4
     no_tp = not self.config.pretraining_tp > 1
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 502fabfa..f009d223 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -116,10 +116,11 @@ def is_enough_kv_cache_room_4_36(past_key_value, idx):
         past_key_value.key_cache[idx].size(3)
 
 
-def is_enough_kv_cache_room_4_31(past_key_value):
+def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
     # to determinate if is enough kv cache room in transformers between 4.31 and 4.35
     return past_key_value is not None and \
-        past_key_value[0].stride()[1] > past_key_value[0].size(2) * past_key_value[0].size(3)
+        past_key_value[0].stride()[1] > \
+        (past_key_value[0].size(2) + seq_len - 1) * past_key_value[0].size(3)
 
 
 def use_flash_attention(query):
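
Note (not part of the patch itself): a hedged sketch of what the bounds check is doing. For a llama-style cache of shape [bs, n_head, seq_len, head_dim] that is pre-allocated with spare slots along the seq_len dimension, stride()[1] encodes the reserved capacity per head; the fixed check only reports "enough room" when the tokens already stored plus the incoming seq_len tokens still fit. The shapes and numbers below are illustrative assumptions, not values taken from BigDL.

    import torch

    # Hypothetical cache: 4 spare slots reserved along the sequence dimension.
    bs, n_head, head_dim = 1, 32, 128
    current_len, max_len = 100, 104
    cache = torch.empty(bs, n_head, max_len, head_dim)
    key_states = cache[:, :, :current_len, :]   # view over the filled prefix

    def has_room(key, seq_len=1):
        # Mirrors the fixed is_enough_kv_cache_room_4_31 condition: the reserved
        # span per head (stride over dim 1) must cover the existing tokens plus
        # the seq_len tokens about to be appended.
        return key.stride()[1] > (key.size(2) + seq_len - 1) * key.size(3)

    print(has_room(key_states, seq_len=1))   # True: 100 + 1 <= 104
    print(has_room(key_states, seq_len=8))   # False: 100 + 8 > 104, would overflow

Without the seq_len argument, a prompt longer than the remaining reserved slots could pass the old check and then write past the end of the pre-allocated cache, which is the out-of-bound the patch title refers to.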