parent 6be70283b7
commit 16b2a418be
1 changed file with 9 additions and 4 deletions
@@ -855,11 +855,13 @@ def llama_attention_selective_batching_forward_4_31(
                 current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
                 current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)
 
+                cache_position = None
                 current_query_states = query_states[batch: batch + 1, :, :, :]
                 attn_output, attn_weights = native_sdp(current_query_states,
                                                        current_key_states,
                                                        current_value_states,
                                                        attention_mask[batch],
+                                                       cache_position,
                                                        1,
                                                        1,
                                                        current_kv_len,
@@ -901,10 +903,12 @@ def llama_attention_selective_batching_forward_4_31(
         if isinstance(attention_mask, list):
             # For decoding fast path
             attention_mask = attention_mask[0]
+        cache_position = None
         attn_output, attn_weights = native_sdp(query_states,
                                                key_states,
                                                value_states,
                                                attention_mask,
+                                               cache_position,
                                                bsz,
                                                q_len,
                                                kv_seq_len,
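Note: both call sites above in llama_attention_selective_batching_forward_4_31 now define cache_position = None before calling native_sdp, and the same pattern appears again in llama_attention_fast_forward below. Passing None is what routes native_sdp_split_qkv_tensor into its legacy masking branch (see the hunk at line 1502). A minimal, self-contained sketch of that dispatch, using made-up shapes rather than the library's tensors:

import torch

# Hypothetical decode-shaped mask: (bsz, 1, q_len, kv_seq_len)
bsz, q_len, kv_seq_len = 1, 1, 8
attention_mask = torch.zeros(bsz, 1, q_len, kv_seq_len)
cache_position = None  # what the patched call sites now pass explicitly

if cache_position is not None:
    # transformers 4.38 style: slice the prepared causal mask by position
    causal_mask = attention_mask[:, :, cache_position, :kv_seq_len]
else:
    # legacy style: the caller's (bsz, 1, q_len, kv_seq_len) mask is used as-is
    causal_mask = attention_mask

print(causal_mask.shape)  # torch.Size([1, 1, 1, 8])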
@@ -1445,7 +1449,7 @@ def llama_attention_forward_4_38_original(
 def native_sdp(query, key, value, attention_mask, cache_position,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
     if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
-        return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+        return native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
                                            bsz, q_len, kv_seq_len, head_dim, num_heads)
     else:
         attn_weights = torch.matmul(query.to(key.dtype),
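Note: before this change, native_sdp forwarded its arguments to native_sdp_split_qkv_tensor without cache_position, so every later positional argument landed one parameter early and the call could not bind all required parameters. A toy reproduction of that failure mode (reduced, hypothetical signature; not the library code):

def split_sdp(query, key, value, attention_mask, cache_position,
              bsz, q_len, kv_seq_len):
    # Stand-in for native_sdp_split_qkv_tensor: report what two parameters received.
    return cache_position, bsz

q = k = v = mask = object()

# Old call shape (cache_position omitted): bsz lands in cache_position,
# q_len lands in bsz, and kv_seq_len is left unbound.
try:
    split_sdp(q, k, v, mask, 1, 1, 16)
except TypeError as err:
    print("old call fails:", err)

# New call shape: every argument reaches the parameter it was meant for.
print(split_sdp(q, k, v, mask, None, 1, 1, 16))  # (None, 1)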
@@ -1502,14 +1506,14 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
         if cache_position is not None:
             # for transformers 4.38.0
             causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
-            attn_weights = attn_weights + causal_mask
+            attn_weights_split = attn_weights_split + causal_mask
         else:
             attn_mask_size = (bsz, 1, q_len, kv_seq_len)
             if attention_mask.size() != attn_mask_size:
                 invalidInputError(False,
                                   f"Attention mask should be of size {attn_mask_size}, "
                                   f"but is {attention_mask.size()}")
-            attn_weights = attn_weights + attention_mask
+            attn_weights_split = attn_weights_split + attention_mask
         attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
         attn_outputs.append(torch.matmul(attn_weights_split, v))
     attn_output = torch.cat(attn_outputs, dim=1)
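Note: the two renamed lines fix a variable mix-up. Inside the split path the per-chunk scores live in attn_weights_split, so adding the mask into attn_weights meant the masked sum never reached the tensor that is soft-maxed two lines later. A minimal, self-contained sketch of the split-softmax pattern this function implements (the chunking, shapes, and scaling factor here are illustrative assumptions, not the library's heuristics):

import math
import torch
import torch.nn as nn

bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 4, 1, 8, 16
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
value = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
attention_mask = torch.zeros(bsz, 1, q_len, kv_seq_len)

attn_outputs = []
for q, k, v in zip(query.chunk(2, dim=1), key.chunk(2, dim=1), value.chunk(2, dim=1)):
    attn_weights_split = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
    # The fix: mask the *split* scores, so the softmax below actually sees it.
    attn_weights_split = attn_weights_split + attention_mask
    attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
    attn_outputs.append(torch.matmul(attn_weights_split, v))

attn_output = torch.cat(attn_outputs, dim=1)
print(attn_output.shape)  # torch.Size([1, 4, 1, 16])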
@@ -1767,8 +1771,9 @@ def llama_attention_fast_forward(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+    cache_position = None
     attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                           attention_mask,
+                                           attention_mask, cache_position,
                                            bsz, q_len, kv_seq_len,
                                            self.head_dim, self.num_heads, output_attentions)
 