parent
6be70283b7
commit
16b2a418be
1 changed files with 9 additions and 4 deletions
|
|
@ -855,11 +855,13 @@ def llama_attention_selective_batching_forward_4_31(
|
|||
current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
|
||||
current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)
|
||||
|
||||
cache_position = None
|
||||
current_query_states = query_states[batch: batch + 1, :, :, :]
|
||||
attn_output, attn_weights = native_sdp(current_query_states,
|
||||
current_key_states,
|
||||
current_value_states,
|
||||
attention_mask[batch],
|
||||
cache_position,
|
||||
1,
|
||||
1,
|
||||
current_kv_len,
|
||||
|
|
@ -901,10 +903,12 @@ def llama_attention_selective_batching_forward_4_31(
|
|||
if isinstance(attention_mask, list):
|
||||
# For decoding fast path
|
||||
attention_mask = attention_mask[0]
|
||||
cache_position = None
|
||||
attn_output, attn_weights = native_sdp(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
cache_position,
|
||||
bsz,
|
||||
q_len,
|
||||
kv_seq_len,
|
||||
|
|
@ -1445,7 +1449,7 @@ def llama_attention_forward_4_38_original(
|
|||
def native_sdp(query, key, value, attention_mask, cache_position,
|
||||
bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
|
||||
if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
|
||||
return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
|
||||
return native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
|
||||
bsz, q_len, kv_seq_len, head_dim, num_heads)
|
||||
else:
|
||||
attn_weights = torch.matmul(query.to(key.dtype),
|
||||
|
|
@ -1502,14 +1506,14 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_positio
|
|||
if cache_position is not None:
|
||||
# for transformers 4.38.0
|
||||
causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights_split = attn_weights_split + causal_mask
|
||||
else:
|
||||
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
|
||||
if attention_mask.size() != attn_mask_size:
|
||||
invalidInputError(False,
|
||||
f"Attention mask should be of size {attn_mask_size}, "
|
||||
f"but is {attention_mask.size()}")
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights_split = attn_weights_split + attention_mask
|
||||
attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
|
||||
attn_outputs.append(torch.matmul(attn_weights_split, v))
|
||||
attn_output = torch.cat(attn_outputs, dim=1)
|
||||
|
|
@ -1767,8 +1771,9 @@ def llama_attention_fast_forward(
|
|||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
cache_position = None
|
||||
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
|
||||
attention_mask,
|
||||
attention_mask, cache_position,
|
||||
bsz, q_len, kv_seq_len,
|
||||
self.head_dim, self.num_heads, output_attentions)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue