diff --git a/python/llm/src/bigdl/llm/transformers/models/gptj.py b/python/llm/src/bigdl/llm/transformers/models/gptj.py
index 9c872fe7..794cf291 100644
--- a/python/llm/src/bigdl/llm/transformers/models/gptj.py
+++ b/python/llm/src/bigdl/llm/transformers/models/gptj.py
@@ -134,7 +134,7 @@ def gptj_attention_forward(
     device = hidden_states.device
 
     if layer_past is not None:
-        kv_seq_len += layer_past[0].size(1)
+        kv_seq_len += layer_past[0].size(2)
 
     if layer_past is not None:
         cache_k = layer_past[0]
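
Note (not part of the patch): the fix reads the cached sequence length from dim 2 instead of dim 1. This presumes the cached key tensor layer_past[0] is laid out as [batch, num_heads, seq_len, head_dim], so size(1) was returning the head count rather than the past length. A minimal sketch under that assumed layout (shape values are illustrative, not taken from the source):

import torch

# Assumed KV-cache layout: [batch, num_heads, seq_len, head_dim]
batch, num_heads, seq_len, head_dim = 1, 16, 10, 256
cache_k = torch.zeros(batch, num_heads, seq_len, head_dim)

kv_seq_len = 3                    # tokens in the current step
kv_seq_len += cache_k.size(2)     # adds the cached sequence length (10) -> 13
# cache_k.size(1) would instead add num_heads (16), inflating kv_seq_len
print(kv_seq_len)                 # 13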