From 07f36fbfcc1b8960e68c4541c2bbe12cc75f0603 Mon Sep 17 00:00:00 2001 From: Yina Chen <33650826+cyita@users.noreply.github.com> Date: Thu, 29 Feb 2024 09:39:27 +0800 Subject: [PATCH] Fix gptj failed to extend (#10269) --- python/llm/src/bigdl/llm/transformers/models/gptj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/gptj.py b/python/llm/src/bigdl/llm/transformers/models/gptj.py index 9c872fe7..794cf291 100644 --- a/python/llm/src/bigdl/llm/transformers/models/gptj.py +++ b/python/llm/src/bigdl/llm/transformers/models/gptj.py @@ -134,7 +134,7 @@ def gptj_attention_forward( device = hidden_states.device if layer_past is not None: - kv_seq_len += layer_past[0].size(1) + kv_seq_len += layer_past[0].size(2) if layer_past is not None: cache_k = layer_past[0]