diff --git a/python/llm/src/bigdl/llm/transformers/models/gptj.py b/python/llm/src/bigdl/llm/transformers/models/gptj.py
index e904a520..6a4e0aff 100644
--- a/python/llm/src/bigdl/llm/transformers/models/gptj.py
+++ b/python/llm/src/bigdl/llm/transformers/models/gptj.py
@@ -134,7 +134,7 @@ def gptj_attention_forward(
     device = hidden_states.device

     if layer_past is not None:
-        kv_seq_len += layer_past[0].size(-2)
+        kv_seq_len += layer_past[0].size(1)

     if layer_past is not None:
         cache_k = layer_past[0]
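
The one-line change above switches which axis of the cached key tensor is read as the past sequence length. A minimal standalone sketch of why the index matters is below; it assumes the cached key in this GPT-J path is laid out as (batch, seq_len, num_heads, head_dim), so the sequence length sits on dim 1 rather than dim -2. The shapes and variable names are illustrative only, not taken from the patch.

import torch

# Assumed cache layout for illustration: (batch, seq_len, num_heads, head_dim).
# With this layout, dim 1 is the number of already-cached tokens, while dim -2
# is the number of attention heads.
batch, past_len, num_heads, head_dim = 1, 8, 16, 256
layer_past_key = torch.empty(batch, past_len, num_heads, head_dim)

kv_seq_len = 1                          # tokens in the current decoding step
kv_seq_len += layer_past_key.size(1)    # 1 + 8 = 9, the intended total length
wrong = 1 + layer_past_key.size(-2)     # 1 + 16 = 17, counts heads instead

print(kv_seq_len, wrong)                # 9 17

Under that layout assumption, size(-2) would silently return the head count and inflate kv_seq_len, which is why the patch reads size(1) instead.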