From 33b9e7744d2505b96dd4368270cf97b168d43bca Mon Sep 17 00:00:00 2001
From: Jiao Wang
Date: Mon, 5 Feb 2024 15:07:38 -0800
Subject: [PATCH] fix dimension (#10097)

---
 python/llm/src/bigdl/llm/transformers/models/gptj.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/gptj.py b/python/llm/src/bigdl/llm/transformers/models/gptj.py
index 219c345a..5e0622d4 100644
--- a/python/llm/src/bigdl/llm/transformers/models/gptj.py
+++ b/python/llm/src/bigdl/llm/transformers/models/gptj.py
@@ -123,7 +123,7 @@ def gptj_attention_forward(
         key = torch.cat([k_rot, k_pass], dim=-1)
         query = torch.cat([q_rot, q_pass], dim=-1)
     else:
-        query, key = apply_rotary_pos_emb(query, k_rot, cos, sin, position_ids, "gptj")
+        query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids, "gptj")
 
     batch_size, q_len, _ = hidden_states.size()
 
@@ -139,8 +139,6 @@ def gptj_attention_forward(
     if layer_past is not None:
         cache_k = layer_past[0]
         cache_v = layer_past[1]
-        cache_k = cache_k.permute(0, 2, 1, 3)
-        cache_v = cache_v.permute(0, 2, 1, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
@@ -170,7 +168,7 @@ def gptj_attention_forward(
         value = value_cache
 
     if use_cache is True:
-        present = (key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3))
+        present = (key, value)
    else:
         present = None
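
The hunks above drop the per-step permutes and return the cached key/value tensors as-is, i.e. they keep the cache in one layout throughout the forward pass. Below is a minimal sketch of that idea, assuming a (batch, num_heads, seq_len, head_dim) cache layout; the shapes and variable names here are illustrative assumptions, not taken from the patch.

```python
import torch

# Illustrative shapes only; these values are assumptions for the sketch.
batch, num_heads, head_dim = 1, 16, 256
past_len, new_len = 8, 1

# Keep the cache in (batch, num_heads, seq_len, head_dim) from the start.
cache_k = torch.randn(batch, num_heads, past_len, head_dim)
cache_v = torch.randn(batch, num_heads, past_len, head_dim)
key = torch.randn(batch, num_heads, new_len, head_dim)
value = torch.randn(batch, num_heads, new_len, head_dim)

# With a fixed layout, the sequence length always lives on dim 2.
past_length = cache_k.size(2)
key = torch.cat([cache_k, key], dim=2)
value = torch.cat([cache_v, value], dim=2)

# The concatenated tensors can be returned as the new cache directly,
# with no permute before storing or after loading.
present = (key, value)
```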