diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index a080847b..64f06626 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -192,8 +192,7 @@ def qwen_attention_forward(
         cache_k, cache_v = layer_past[0], layer_past[1]
         cache_k = cache_k.transpose(1, 2)
         cache_v = cache_v.transpose(1, 2)
-        kv_seq_len += cache_k.shape[2]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride(1) < kv_seq_len * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz, self.num_heads,