fix dimension (#10097)
This commit is contained in:
parent 4b02ff188b
commit 33b9e7744d
1 changed file with 2 additions and 4 deletions
@@ -123,7 +123,7 @@ def gptj_attention_forward(
         key = torch.cat([k_rot, k_pass], dim=-1)
         query = torch.cat([q_rot, q_pass], dim=-1)
     else:
-        query, key = apply_rotary_pos_emb(query, k_rot, cos, sin, position_ids, "gptj")
+        query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids, "gptj")

     batch_size, q_len, _ = hidden_states.size()

@@ -139,8 +139,6 @@ def gptj_attention_forward(
     if layer_past is not None:
         cache_k = layer_past[0]
         cache_v = layer_past[1]
-        cache_k = cache_k.permute(0, 2, 1, 3)
-        cache_v = cache_v.permute(0, 2, 1, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
@@ -170,7 +168,7 @@ def gptj_attention_forward(
         value = value_cache

     if use_cache is True:
-        present = (key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3))
+        present = (key, value)
     else:
         present = None
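Taken together, the second and third hunks keep the KV cache in the same (batch, num_heads, seq_len, head_dim) layout that key and value already use for attention, instead of permuting to (batch, seq_len, num_heads, head_dim) when storing present and permuting back when reading layer_past; the first hunk additionally passes the full key tensor to apply_rotary_pos_emb rather than k_rot in the non-partial-rotary branch. Below is a minimal, hypothetical sketch (not part of the commit) of the permute round trip the fix removes; the shape values are made up for illustration.

import torch

# Hypothetical shapes, for illustration only.
batch, num_heads, seq_len, head_dim = 1, 16, 8, 256

# Key in attention layout: (batch, num_heads, seq_len, head_dim).
key = torch.randn(batch, num_heads, seq_len, head_dim)

# Pre-fix path suggested by the removed lines: permute for caching,
# then permute back when the cache is read on the next step.
cached_old = key.permute(0, 2, 1, 3)        # (batch, seq_len, num_heads, head_dim)
restored = cached_old.permute(0, 2, 1, 3)   # back to (batch, num_heads, seq_len, head_dim)

# Post-fix path: cache the tensor in its attention layout directly,
# so past_length can be read from dim 2 without any permute.
cached_new = key

assert torch.equal(restored, cached_new)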