fix dimension (#10097)
parent 4b02ff188b
commit 33b9e7744d
1 changed file with 2 additions and 4 deletions
@@ -123,7 +123,7 @@ def gptj_attention_forward(
         key = torch.cat([k_rot, k_pass], dim=-1)
         query = torch.cat([q_rot, q_pass], dim=-1)
     else:
-        query, key = apply_rotary_pos_emb(query, k_rot, cos, sin, position_ids, "gptj")
+        query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids, "gptj")

     batch_size, q_len, _ = hidden_states.size()
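Context for the first hunk, as a minimal sketch with made-up shapes (the real apply_rotary_pos_emb is stubbed out here as a hypothetical rotate helper): when only part of the head dimension is rotary, the k_rot/q_rot slices are rotated and concatenated back with the pass-through slices; in the full-rotation else branch the whole tensors go through the helper, so the call has to receive key rather than the k_rot slice that only exists in the other branch.

    import torch

    def rotate(q, k):
        # hypothetical stand-in for apply_rotary_pos_emb; returns its inputs unchanged
        return q, k

    batch, heads, seq, head_dim, rotary_dim = 1, 4, 8, 64, 32
    query = torch.randn(batch, heads, seq, head_dim)
    key = torch.randn(batch, heads, seq, head_dim)

    partial_rotary = True
    if partial_rotary:
        # rotate only the first rotary_dim channels, keep the rest untouched
        k_rot, k_pass = key[..., :rotary_dim], key[..., rotary_dim:]
        q_rot, q_pass = query[..., :rotary_dim], query[..., rotary_dim:]
        q_rot, k_rot = rotate(q_rot, k_rot)
        key = torch.cat([k_rot, k_pass], dim=-1)
        query = torch.cat([q_rot, q_pass], dim=-1)
    else:
        # full-head rotation: pass the whole key, not the k_rot slice
        query, key = rotate(query, key)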
@@ -139,8 +139,6 @@ def gptj_attention_forward(
     if layer_past is not None:
         cache_k = layer_past[0]
         cache_v = layer_past[1]
-        cache_k = cache_k.permute(0, 2, 1, 3)
-        cache_v = cache_v.permute(0, 2, 1, 3)
         past_length = cache_k.size(2)
         if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
@@ -170,7 +168,7 @@ def gptj_attention_forward(
         value = value_cache

     if use_cache is True:
-        present = (key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3))
+        present = (key, value)
     else:
         present = None
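On the second and third hunks: the removed permutes assumed the cached key/value were laid out as (batch, seq_len, num_heads, head_dim); if they are instead kept as (batch, num_heads, seq_len, head_dim), then permute(0, 2, 1, 3) swaps the wrong axes and cache_k.size(2) reads the head count rather than the past length, and writing present without a permute keeps the stored cache consistent with how layer_past is read back above. A small illustration under those assumed shapes:

    import torch

    batch, heads, seq, head_dim = 1, 4, 8, 64
    cache_k = torch.randn(batch, heads, seq, head_dim)  # assumed (B, H, S, D) layout

    print(cache_k.size(2))                      # 8 -> the past sequence length
    print(cache_k.permute(0, 2, 1, 3).size(2))  # 4 -> the head count, i.e. the wrong dimension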