fix dimension (#10097)

commit 33b9e7744d
parent 4b02ff188b
Author: Jiao Wang
Date:   2024-02-05 15:07:38 -08:00
Committed by: GitHub


@@ -123,7 +123,7 @@ def gptj_attention_forward(
         key = torch.cat([k_rot, k_pass], dim=-1)
         query = torch.cat([q_rot, q_pass], dim=-1)
     else:
-        query, key = apply_rotary_pos_emb(query, k_rot, cos, sin, position_ids, "gptj")
+        query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids, "gptj")
     batch_size, q_len, _ = hidden_states.size()
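
Note: this first hunk is the dimension fix named in the commit title. In the else branch the full key tensor has to go into apply_rotary_pos_emb, not only the rotary slice k_rot. A toy sketch of the shape problem, assuming k_rot is the last-dimension slice key[..., :rotary_dim] (as the torch.cat in the context lines suggests) and using a stand-in helper rather than the repository's apply_rotary_pos_emb:

    import torch

    # Illustrative shapes only; not taken from the model config.
    batch, heads, seq, head_dim, rotary_dim = 1, 4, 3, 64, 32

    query = torch.randn(batch, heads, seq, head_dim)
    key = torch.randn(batch, heads, seq, head_dim)
    k_rot = key[..., :rotary_dim]   # rotary slice, last dim = 32

    def rotary_stub(q, k):
        # Stand-in for apply_rotary_pos_emb: shape-preserving, no real rotation.
        return q.clone(), k.clone()

    # Buggy call: the key that comes back only has the rotary slice's width.
    _, key_bad = rotary_stub(query, k_rot)
    print(key_bad.shape)   # torch.Size([1, 4, 3, 32]) -- wrong last dimension

    # Fixed call: pass the full key, so the full head dimension is preserved.
    _, key_ok = rotary_stub(query, key)
    print(key_ok.shape)    # torch.Size([1, 4, 3, 64])
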
@@ -139,8 +139,6 @@ def gptj_attention_forward(
     if layer_past is not None:
         cache_k = layer_past[0]
         cache_v = layer_past[1]
-        cache_k = cache_k.permute(0, 2, 1, 3)
-        cache_v = cache_v.permute(0, 2, 1, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
@@ -170,7 +168,7 @@ def gptj_attention_forward(
         value = value_cache
     if use_cache is True:
-        present = (key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3))
+        present = (key, value)
     else:
         present = None
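
Note: the second and third hunks drop the permutes around the KV cache, which suggests the cached tensors and the returned present tuple now stay in a single (batch, num_heads, seq_len, head_dim) layout instead of being transposed on the way in and out. A minimal sketch under that assumption (shapes and names here are illustrative, not the repository's):

    import torch

    # Illustrative shapes only.
    batch, heads, head_dim = 2, 16, 256
    past_len, q_len = 8, 1

    # Cache kept in (batch, heads, seq, head_dim); no permute(0, 2, 1, 3)
    # is needed before reading it or before returning `present`.
    cache_k = torch.randn(batch, heads, past_len, head_dim)
    cache_v = torch.randn(batch, heads, past_len, head_dim)
    key = torch.randn(batch, heads, q_len, head_dim)
    value = torch.randn(batch, heads, q_len, head_dim)

    past_length = cache_k.size(2)          # dim 2 is the sequence axis here
    key = torch.cat([cache_k, key], dim=2)
    value = torch.cat([cache_v, value], dim=2)

    present = (key, value)                 # returned as-is, matching the new hunk
    print(past_length, key.shape)          # 8 torch.Size([2, 16, 9, 256])
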