diff --git a/python/llm/src/bigdl/llm/transformers/models/gptj.py b/python/llm/src/bigdl/llm/transformers/models/gptj.py
index 219c345a..5e0622d4 100644
--- a/python/llm/src/bigdl/llm/transformers/models/gptj.py
+++ b/python/llm/src/bigdl/llm/transformers/models/gptj.py
@@ -123,7 +123,7 @@ def gptj_attention_forward(
         key = torch.cat([k_rot, k_pass], dim=-1)
         query = torch.cat([q_rot, q_pass], dim=-1)
     else:
-        query, key = apply_rotary_pos_emb(query, k_rot, cos, sin, position_ids, "gptj")
+        query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids, "gptj")
 
     batch_size, q_len, _ = hidden_states.size()
 
@@ -139,8 +139,6 @@ def gptj_attention_forward(
     if layer_past is not None:
         cache_k = layer_past[0]
         cache_v = layer_past[1]
-        cache_k = cache_k.permute(0, 2, 1, 3)
-        cache_v = cache_v.permute(0, 2, 1, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
@@ -170,7 +168,7 @@ def gptj_attention_forward(
         value = value_cache
 
     if use_cache is True:
-        present = (key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3))
+        present = (key, value)
     else:
         present = None
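
A minimal sketch (not part of the patch) of the layout assumption behind the removed permutes: if the cached key/value tensors are already kept in (batch, num_heads, seq_len, head_dim) order, they can be concatenated with the new states and stored in `present` as-is, so the `permute(0, 2, 1, 3)` round-trips become unnecessary. The sizes below are illustrative only.

import torch

# Illustrative sizes; GPT-J 6B uses 16 heads of width 256 (hidden size 4096).
batch, num_heads, head_dim = 1, 16, 256
past_len, new_len = 8, 1

# Hypothetical cache and new key states, both already in
# (batch, num_heads, seq_len, head_dim) layout.
cache_k = torch.randn(batch, num_heads, past_len, head_dim)
new_k = torch.randn(batch, num_heads, new_len, head_dim)

# Concatenate along the sequence dimension and cache the result directly;
# no permute is needed before returning it as part of `present`.
key = torch.cat([cache_k, new_k], dim=2)
present_k = key
assert present_k.shape == (batch, num_heads, past_len + new_len, head_dim)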