LLM: fix wrong length in gptj kv_cache optimization (#9210)
* fix wrong length in gptj kv cache
* update
This commit is contained in:
parent 6dad8d16df
commit 3555ebc148
1 changed file with 1 addition and 1 deletion
@@ -134,7 +134,7 @@ def gptj_attention_forward(
     device = hidden_states.device
 
     if layer_past is not None:
-        kv_seq_len += layer_past[0].size(-2)
+        kv_seq_len += layer_past[0].size(1)
 
     if layer_past is not None:
         cache_k = layer_past[0]
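For context, a minimal sketch of why dim 1, rather than dim -2, gives the cached sequence length here. It assumes the optimized GPT-J kv cache keeps keys in the (batch, seq_len, num_heads, head_dim) layout that GPT-J uses before the heads are permuted forward; the variable names and shapes below are illustrative and not taken from this file.

    import torch

    # Assumed cache layout for the optimized GPT-J path:
    # (batch, seq_len, num_heads, head_dim). GPT-J-6B sizes used for illustration.
    batch, past_len, num_heads, head_dim = 1, 12, 16, 256
    cached_key = torch.randn(batch, past_len, num_heads, head_dim)

    kv_seq_len = 1  # tokens in the current forward pass

    # Old code: with this layout, dim -2 is num_heads, not the sequence length.
    wrong_len = kv_seq_len + cached_key.size(-2)  # 1 + 16 = 17

    # Fixed code: dim 1 is the sequence dimension, matching the commit.
    right_len = kv_seq_len + cached_key.size(1)   # 1 + 12 = 13

    print(wrong_len, right_len)

Under that assumed layout, .size(-2) silently returns the head count, which inflates kv_seq_len and breaks anything derived from it (position ids, attention masks, cache allocation); indexing the sequence dimension explicitly avoids this.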