From 3555ebc1481b3af969ce3e26899c51183247f337 Mon Sep 17 00:00:00 2001 From: Ruonan Wang <105281011+rnwang04@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:59:02 +0800 Subject: [PATCH] LLM: fix wrong length in gptj kv_cache optimization (#9210) * fix wrong length in gptj kv cache * update --- python/llm/src/bigdl/llm/transformers/models/gptj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/gptj.py b/python/llm/src/bigdl/llm/transformers/models/gptj.py index e904a520..6a4e0aff 100644 --- a/python/llm/src/bigdl/llm/transformers/models/gptj.py +++ b/python/llm/src/bigdl/llm/transformers/models/gptj.py @@ -134,7 +134,7 @@ def gptj_attention_forward( device = hidden_states.device if layer_past is not None: - kv_seq_len += layer_past[0].size(-2) + kv_seq_len += layer_past[0].size(1) if layer_past is not None: cache_k = layer_past[0]