diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index d4e9e223..f01d0cd3 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -149,7 +149,7 @@ def llama_attention_forward_4_31(
         attention_dtype = original_dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
+    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     qtype = getattr(self.q_proj, "qtype", None)
     is_q4_0 = qtype == SYM_INT4
     no_tp = not self.config.pretraining_tp > 1
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 502fabfa..f009d223 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -116,10 +116,11 @@ def is_enough_kv_cache_room_4_36(past_key_value, idx):
         past_key_value.key_cache[idx].size(3)
 
 
-def is_enough_kv_cache_room_4_31(past_key_value):
+def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
     # to determinate if is enough kv cache room in transformers between 4.31 and 4.35
     return past_key_value is not None and \
-        past_key_value[0].stride()[1] > past_key_value[0].size(2) * past_key_value[0].size(3)
+        past_key_value[0].stride()[1] > \
+        (past_key_value[0].size(2) + seq_len - 1) * past_key_value[0].size(3)
 
 
 def use_flash_attention(query):
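
For reference, the sketch below (not part of the patch) walks through the stride/size arithmetic behind the updated check. The batch size, head count, and cache lengths are hypothetical; only the comparison mirrors `is_enough_kv_cache_room_4_31`. The idea is that `stride()[1]` of the cached key tensor reflects the pre-allocated capacity of the buffer per head, while `size(2)` only reports the tokens already stored, so adding `seq_len - 1` makes the check ask whether all `seq_len` incoming tokens fit, not just one.

```python
# Minimal illustrative sketch of the kv-cache-room check; sizes are made up.
import torch

bs, n_head, head_dim = 1, 32, 128
max_cache_len = 16          # pre-allocated room along the sequence dimension
cur_len = 14                # tokens already stored in the cache

# Pre-allocate a larger buffer, then narrow it to the currently used length.
# Narrowing along dim 2 keeps stride(1) at max_cache_len * head_dim (the
# allocated capacity), while size(2) only reports cur_len.
buffer = torch.empty(bs, n_head, max_cache_len, head_dim)
key_states = buffer[:, :, :cur_len, :]
past_key_value = (key_states, key_states)  # (key, value) tuple, transformers 4.31 style

def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
    # Same comparison as the patched helper: capacity must cover the stored
    # tokens plus all seq_len incoming tokens.
    return past_key_value is not None and \
        past_key_value[0].stride()[1] > \
        (past_key_value[0].size(2) + seq_len - 1) * past_key_value[0].size(3)

# One new token still fits (14 + 1 <= 16), but a 4-token chunk does not.
print(is_enough_kv_cache_room_4_31(past_key_value, seq_len=1))  # True
print(is_enough_kv_cache_room_4_31(past_key_value, seq_len=4))  # False
```

With the old form, the check only ever accounted for one incoming token, so passing `seq_len=q_len` at the call site in `llama_attention_forward_4_31` is what lets multi-token inputs be judged against the remaining room correctly.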