LLM: fix wrong length in gptj kv_cache optimization (#9210)

* fix wrong length in gptj kv cache

* update
Ruonan Wang 2023-10-18 14:59:02 +08:00 committed by GitHub
parent 6dad8d16df
commit 3555ebc148


@@ -134,7 +134,7 @@ def gptj_attention_forward(
     device = hidden_states.device
     if layer_past is not None:
-        kv_seq_len += layer_past[0].size(-2)
+        kv_seq_len += layer_past[0].size(1)
     if layer_past is not None:
         cache_k = layer_past[0]
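
Why the one-character change matters: with this kv_cache optimization, the cached key tensor keeps its sequence length on dim 1, so size(-2) reads the wrong axis (the heads dimension) and inflates kv_seq_len. A minimal sketch of the bug, assuming an illustrative (batch, seq_len, num_heads, head_dim) cache layout; all shapes and names below are hypothetical, not taken from the repo:

import torch

# Assumed layout for the optimized cache: (batch, seq_len, num_heads, head_dim).
batch, num_heads, head_dim = 1, 16, 256
past_len, cur_len = 8, 1

# Hypothetical cached key tensor, standing in for layer_past[0].
cache_k = torch.zeros(batch, past_len, num_heads, head_dim)

kv_seq_len = cur_len
kv_seq_len += cache_k.size(1)   # seq_len dim -> 8, so kv_seq_len == 9 (correct)
# cache_k.size(-2) would read num_heads (16) instead, yielding a wrong length of 17.

print(kv_seq_len)  # 9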