From aae1870096e7cb432ccc74a503a7e0f3529ff985 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 26 Jan 2024 10:15:01 +0800
Subject: [PATCH] fix qwen kv cache length (#9998)

---
 python/llm/src/bigdl/llm/transformers/models/qwen.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index a080847b..64f06626 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -192,8 +192,7 @@ def qwen_attention_forward(
         cache_k, cache_v = layer_past[0], layer_past[1]
         cache_k = cache_k.transpose(1, 2)
         cache_v = cache_v.transpose(1, 2)
-        kv_seq_len += cache_k.shape[2]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride(1) < kv_seq_len * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz, self.num_heads,
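
Note on the check above: assuming extend_kv_cache pre-allocates a
(bsz, num_heads, max_len, head_dim) buffer and the live cache is a narrow()
view along the sequence dimension, cache_k.stride(1) still equals
max_len * head_dim, i.e. the per-head capacity in elements. The new condition
therefore reallocates whenever the total required length kv_seq_len (past plus
new tokens) exceeds that capacity, whereas the old "<=" test against
cache_k.size(2) * cache_k.size(3) only fired once the buffer was exactly full.
The dropped "kv_seq_len += cache_k.shape[2]" presumably double-counted the past
length, since kv_seq_len already includes it by this point; reading the diff,
that is the "kv cache length" bug named in the subject. A minimal standalone
sketch of the stride-based capacity check (needs_realloc and all sizes are
illustrative, not BigDL API):

    import torch

    bsz, num_heads, head_dim = 1, 32, 128   # illustrative sizes
    max_len, cur_len = 256, 200

    # Pre-allocated buffer, as extend_kv_cache is assumed to produce.
    buffer = torch.empty(bsz, num_heads, max_len, head_dim)
    # Live cache: a view of the first cur_len filled positions (dim 2).
    cache_k = buffer.narrow(2, 0, cur_len)

    # narrow() keeps the parent's strides, so stride(1) still encodes
    # the full per-head capacity: max_len * head_dim elements.
    assert cache_k.stride(1) == max_len * head_dim

    def needs_realloc(cache_k, kv_seq_len):
        # True when the buffer cannot hold kv_seq_len total positions.
        return cache_k.stride(1) < kv_seq_len * cache_k.size(3)

    print(needs_realloc(cache_k, cur_len + 1))  # False: 201 fits in 256
    print(needs_realloc(cache_k, max_len + 1))  # True: 257 does not fit

Using the strict "<" against kv_seq_len means the cache is grown before the
new tokens are written, rather than only after it has filled up completely.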