LLM: optimize llama native sdp for split qkv tensor (#10693)

* LLM: optimize llama native sdp for split qkv tensor.
* fix block real size.
* fix comment.
* fix style.
* refactor.
This commit is contained in:
parent 1274cba79b
commit 7c43ac0164

1 changed file with 12 additions and 9 deletions
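The patch threads num_heads through to native_sdp_split_qkv_tensor, shrinks the split size from a hard-coded 16 to block_size = 8, and sizes each block from the actual tensor instead of assuming a full block. The last point is the "fix block real size" bullet: torch.split returns equal chunks along the split dimension except for the final chunk, which is smaller whenever the head count is not divisible by the block size, so the old hard-coded (bsz, 16, q_len, kv_seq_len) shape check could not hold for the last block. A minimal sketch of that torch.split behavior (shapes are illustrative, not from the patch):

import torch

# torch.split yields chunks of size 8 along dim=1 (the head dim here),
# except the last chunk, which may be smaller when the dimension size
# is not divisible by the block size.
x = torch.randn(1, 20, 4, 4)          # e.g. 20 heads, block size 8
chunks = torch.split(x, 8, dim=1)
print([c.size(1) for c in chunks])    # [8, 8, 4] -> last block is partial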
@@ -1331,7 +1331,7 @@ def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
     if should_split_qkv_tensor(query, output_attentions):
         return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
-                                           bsz, q_len, kv_seq_len, head_dim)
+                                           bsz, q_len, kv_seq_len, head_dim, num_heads)
     else:
         attn_weights = torch.matmul(query.to(key.dtype),
                                     key.transpose(2, 3)) / math.sqrt(head_dim)
@@ -1362,14 +1362,17 @@ def native_sdp(query, key, value, attention_mask,
 
 
 def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
-                                bsz, q_len, kv_seq_len, head_dim):
-    query_split = torch.split(query.to(key.dtype), 16, dim=1)
-    key_split = torch.split(key.transpose(2, 3), 16, dim=1)
-    value_split = torch.split(value, 16, dim=1)
-    attn_outputs = []
+                                bsz, q_len, kv_seq_len, head_dim, num_heads):
+    block_size = 8
+    query_split = torch.split(query.to(key.dtype), block_size, dim=1)
+    key_split = torch.split(key.transpose(2, 3), block_size, dim=1)
+    value_split = torch.split(value, block_size, dim=1)
+    attn_output = torch.empty(bsz, num_heads, q_len, head_dim).to(query.device)
+    idx = 0
     for q, k, v in zip(query_split, key_split, value_split):
         attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
-        attn_weights_split_size = (bsz, 16, q_len, kv_seq_len)
+        block_actual_size = attn_weights_split.size(1)
+        attn_weights_split_size = (bsz, block_actual_size, q_len, kv_seq_len)
         if attn_weights_split.size() != attn_weights_split_size:
             invalidInputError(False,
                               f"Splitted attention weights should be of size "
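Preallocating attn_output and writing each block into its slice replaces the old list-append plus torch.cat, which held every per-block output and then allocated a second full-size tensor at concatenation time. A rough sketch of the two patterns side by side (toy shapes, not the patch itself):

import torch

bsz, num_heads, q_len, head_dim = 1, 20, 128, 64

# Stand-ins for the per-block attention outputs (last block partial).
outputs = [torch.randn(bsz, n, q_len, head_dim) for n in (8, 8, 4)]

# Old pattern: collect blocks, then concatenate (extra peak memory).
cat_result = torch.cat(outputs, dim=1)

# New pattern: write each block into its slice of a preallocated buffer.
attn_output = torch.empty(bsz, num_heads, q_len, head_dim)
idx = 0
for block in outputs:
    n = block.size(1)                       # actual block size, last may be partial
    attn_output[:, idx:idx+n, :, :] = block
    idx += n

assert torch.equal(cat_result, attn_output)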
@@ -1384,8 +1387,8 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
             attn_weights_split = attn_weights_split + attention_mask
         attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
         attn_weights_split = torch.matmul(attn_weights_split, v)
-        attn_outputs.append(attn_weights_split)
-    attn_output = torch.cat(attn_outputs, dim=1)
+        attn_output[:, idx:idx+block_actual_size, :, :] = attn_weights_split
+        idx = idx + block_actual_size
     return attn_output, None
 
 
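For reference, a self-contained sketch of the patched block-wise SDP path, with invalidInputError swapped for a plain assert and random tensors standing in for real qkv inputs (a minimal sketch under those assumptions, not the library code):

import math
import torch
from torch import nn

def sdp_split_qkv_sketch(query, key, value, attention_mask,
                         bsz, q_len, kv_seq_len, head_dim, num_heads):
    block_size = 8
    query_split = torch.split(query.to(key.dtype), block_size, dim=1)
    key_split = torch.split(key.transpose(2, 3), block_size, dim=1)
    value_split = torch.split(value, block_size, dim=1)
    attn_output = torch.empty(bsz, num_heads, q_len, head_dim).to(query.device)
    idx = 0
    for q, k, v in zip(query_split, key_split, value_split):
        attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
        block_actual_size = attn_weights_split.size(1)  # last block may be partial
        assert attn_weights_split.size() == (bsz, block_actual_size, q_len, kv_seq_len)
        if attention_mask is not None:
            attn_weights_split = attn_weights_split + attention_mask
        attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
        attn_weights_split = torch.matmul(attn_weights_split, v)
        # Write this block of heads into its slice of the preallocated output.
        attn_output[:, idx:idx+block_actual_size, :, :] = attn_weights_split
        idx = idx + block_actual_size
    return attn_output, None

# Usage with toy shapes (20 heads exercises a partial final block of 4):
bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 20, 16, 16, 64
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
v = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
out, _ = sdp_split_qkv_sketch(q, k, v, None, bsz, q_len, kv_seq_len, head_dim, num_heads)
print(out.shape)  # torch.Size([1, 20, 16, 64])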