Fix gptj failed to extend (#10269)
This commit is contained in:
parent
1572b6f7c3
commit
07f36fbfcc
1 changed files with 1 additions and 1 deletions
|
|
@ -134,7 +134,7 @@ def gptj_attention_forward(
|
||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
|
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
kv_seq_len += layer_past[0].size(1)
|
kv_seq_len += layer_past[0].size(2)
|
||||||
|
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
cache_k = layer_past[0]
|
cache_k = layer_past[0]
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue