fix qwen kv cache length (#9998)

Author: Yishuo Wang (committed by GitHub)
Date:   2024-01-26 10:15:01 +08:00
Parent: 762adc4f9d
Commit: aae1870096


@@ -192,8 +192,7 @@ def qwen_attention_forward(
         cache_k, cache_v = layer_past[0], layer_past[1]
         cache_k = cache_k.transpose(1, 2)
         cache_v = cache_v.transpose(1, 2)
         kv_seq_len += cache_k.shape[2]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-            # allocate new
+        if cache_k.stride(1) < kv_seq_len * cache_k.size(3):
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
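
Why the fix matters: in this style of KV cache the backing buffer is pre-allocated with spare room along the sequence dimension, and the live cache tensor is a narrowed view of it, so stride(1) encodes the allocated capacity (capacity * head_dim) while size(2) is only the currently filled length. The old check reallocated only when the view already filled the whole buffer, which can overflow when more than one token arrives at once; the new check compares capacity against the length needed after appending. Below is a minimal standalone sketch of that difference; the shapes, the narrow()-based layout, and all sizes are illustrative assumptions, not code from the commit.

import torch

# Illustrative sizes only (not taken from the commit).
bsz, num_heads, head_dim = 1, 32, 128
max_cache_len = 256   # pre-allocated capacity along the sequence dim
cur_len = 250         # tokens already in the cache
new_tokens = 16       # tokens arriving in this forward pass

# Pre-allocate a buffer with room for max_cache_len tokens and keep a
# narrowed view over the filled part.  The view inherits the parent's
# strides, so cache_k.stride(1) == max_cache_len * head_dim even though
# cache_k.size(2) == cur_len.
buf = torch.empty(bsz, num_heads, max_cache_len, head_dim)
cache_k = buf.narrow(2, 0, cur_len)

kv_seq_len = new_tokens + cache_k.shape[2]  # total length after appending

# Old check: flags the cache as full only when the current view already
# occupies the whole allocation.  Here 6 slots remain, so it reports
# False even though 16 incoming tokens cannot fit.
old_realloc = cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)

# Fixed check: compares the allocated capacity (encoded in stride(1))
# against what the post-append length actually requires.
new_realloc = cache_k.stride(1) < kv_seq_len * cache_k.size(3)

print(old_realloc)  # False -> would have overflowed the buffer
print(new_realloc)  # True  -> correctly triggers a larger allocation

Running this prints False for the old condition (32768 <= 32000 fails, so no reallocation despite only 6 free slots) and True for the new one (32768 < 266 * 128), which is exactly the case the commit title describes: the cache length check now accounts for the full kv_seq_len rather than the current fill level.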