fix qwen kv cache length (#9998)
parent 762adc4f9d
commit aae1870096
1 changed file with 1 addition and 2 deletions
@@ -192,8 +192,7 @@ def qwen_attention_forward(
         cache_k, cache_v = layer_past[0], layer_past[1]
         cache_k = cache_k.transpose(1, 2)
         cache_v = cache_v.transpose(1, 2)
-        kv_seq_len += cache_k.shape[2]
-        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+        if cache_k.stride(1) < kv_seq_len * cache_k.size(3):
             # allocate new
             new_cache_k, new_cache_v = extend_kv_cache(bsz,
                                                        self.num_heads,
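For context, here is a minimal sketch (not the repository's code) of what the changed condition checks. It assumes the kv cache tensor is a narrow() view over a buffer pre-allocated with spare room along the sequence dimension, which is the layout an extend_kv_cache-style helper implies, and it treats kv_seq_len as the total length the cache must hold after the new tokens are appended, consistent with dropping the extra kv_seq_len increment. All sizes below (bsz, num_heads, head_dim, max_len, cur_len, new_tokens) are made up for illustration.

import torch

# Hypothetical sizes for illustration only.
bsz, num_heads, head_dim = 1, 32, 128
max_len, cur_len = 512, 500

# Pre-allocated buffer with room for max_len tokens; the live cache is a
# narrow() view holding only cur_len of them, so its strides still reflect
# the full buffer.
buffer = torch.empty(bsz, num_heads, max_len, head_dim)
cache_k = buffer.narrow(2, 0, cur_len)

# stride(1) is the per-head capacity in elements: max_len * head_dim.
assert cache_k.stride(1) == max_len * head_dim

new_tokens = 20
kv_seq_len = cur_len + new_tokens  # total length the cache must hold next

# Old condition: compares capacity with the length already stored, so it
# only fires once the buffer is exactly full; appending several tokens at
# once can overflow it.
old_needs_realloc = cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3)

# New condition: compares capacity with the length needed after the append,
# so a larger buffer is requested before the overflow can happen.
new_needs_realloc = cache_k.stride(1) < kv_seq_len * cache_k.size(3)

print(old_needs_realloc, new_needs_realloc)  # False True

With 500 of 512 slots used and 20 new tokens arriving, the old test reports that no reallocation is needed while the new test asks for a larger buffer, which matches the kv cache length problem the commit title describes.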