Fix gptj failed to extend (#10269)

This commit is contained in:
Yina Chen 2024-02-29 09:39:27 +08:00 committed by GitHub
parent 1572b6f7c3
commit 07f36fbfcc

View file

@ -134,7 +134,7 @@ def gptj_attention_forward(
device = hidden_states.device device = hidden_states.device
if layer_past is not None: if layer_past is not None:
kv_seq_len += layer_past[0].size(1) kv_seq_len += layer_past[0].size(2)
if layer_past is not None: if layer_past is not None:
cache_k = layer_past[0] cache_k = layer_past[0]