fix dimension (#10097)

Jiao Wang 2024-02-05 15:07:38 -08:00 committed by GitHub
parent 4b02ff188b
commit 33b9e7744d

@@ -123,7 +123,7 @@ def gptj_attention_forward(
         key = torch.cat([k_rot, k_pass], dim=-1)
         query = torch.cat([q_rot, q_pass], dim=-1)
     else:
-        query, key = apply_rotary_pos_emb(query, k_rot, cos, sin, position_ids, "gptj")
+        query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids, "gptj")
     batch_size, q_len, _ = hidden_states.size()
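
This first hunk is the dimension fix named in the commit title: in the full-rotary branch, the second argument to apply_rotary_pos_emb has to be the complete key tensor, not the k_rot slice used by the partial-rotary branch above, so that the rotation presumably covers the full head dimension of both query and key. Below is a minimal rotate-half sketch of that shape invariant, assuming (batch, seq_len, num_heads, head_dim) tensors and hypothetical helper names; it is not the library's own apply_rotary_pos_emb.

import torch

def rotate_half(x):
    # Rotate-half convention: split the last dim in two and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_sketch(query, key, cos, sin):
    # query and key share the full head_dim, so cos/sin broadcast over both;
    # feeding a rotary-only slice for one of them would break this broadcast.
    return query * cos + rotate_half(query) * sin, key * cos + rotate_half(key) * sin

batch, seq_len, num_heads, head_dim = 1, 8, 16, 64
q = torch.randn(batch, seq_len, num_heads, head_dim)
k = torch.randn(batch, seq_len, num_heads, head_dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim / 2)
emb = torch.cat((freqs, freqs), dim=-1)                       # (seq_len, head_dim)
cos = emb.cos()[None, :, None, :]                             # broadcast over batch and heads
sin = emb.sin()[None, :, None, :]
q_out, k_out = apply_rope_sketch(q, k, cos, sin)              # both keep shape (1, 8, 16, 64)
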
@@ -139,8 +139,6 @@ def gptj_attention_forward(
     if layer_past is not None:
         cache_k = layer_past[0]
         cache_v = layer_past[1]
-        cache_k = cache_k.permute(0, 2, 1, 3)
-        cache_v = cache_v.permute(0, 2, 1, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
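
For context on the unchanged check above: if the cache tensor is a narrowed view of a larger pre-allocated buffer, stride()[1] equals the buffer's full per-head capacity (max_len * head_dim) while size(2) only counts tokens written so far, so the comparison asks whether the buffer can still hold kv_seq_len tokens. A small illustrative sketch under that assumption; extend_kv_cache itself is the library's own helper and is not reproduced here.

import torch

batch, num_heads, head_dim, max_len = 1, 16, 64, 8
buffer = torch.empty(batch, num_heads, max_len, head_dim)
cache_k = buffer.narrow(2, 0, 5)                  # 5 past tokens written into the buffer

kv_seq_len = 6                                    # past tokens plus the new token
# stride()[1] == max_len * head_dim for this layout, so the comparison below asks
# "does the pre-allocated buffer still have room for kv_seq_len tokens?"
needs_extend = cache_k.stride()[1] < kv_seq_len * cache_k.size(3)
print(needs_extend)                               # False: 8 * 64 >= 6 * 64, no re-allocation
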
@@ -170,7 +168,7 @@ def gptj_attention_forward(
         value = value_cache
     if use_cache is True:
-        present = (key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3))
+        present = (key, value)
     else:
         present = None
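
The last two hunks are the other half of the fix: judging from past_length = cache_k.size(2), keys and values are now kept in a single (batch, num_heads, seq_len, head_dim) layout end to end, so neither the incoming layer_past nor the returned present needs the permute(0, 2, 1, 3) round trip any more. A minimal sketch of that single-layout flow, with illustrative shapes and names rather than the library's actual cache helpers.

import torch

batch, num_heads, head_dim = 1, 16, 64
cache_k = torch.randn(batch, num_heads, 5, head_dim)   # 5 past tokens, already (b, h, s, d)
cache_v = torch.randn(batch, num_heads, 5, head_dim)
key = torch.randn(batch, num_heads, 1, head_dim)       # 1 new token in the same layout
value = torch.randn(batch, num_heads, 1, head_dim)

past_length = cache_k.size(2)                          # the sequence axis is dim 2 directly
key = torch.cat([cache_k, key], dim=2)                 # (1, 16, 6, 64)
value = torch.cat([cache_v, value], dim=2)

present = (key, value)                                 # returned as-is, no permute needed
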