LLM: update split qkv native sdp. (#10895)

* LLM: update split qkv native sdp.

* fix typo.
Cengguang Zhang 2024-04-26 18:47:35 +08:00 committed by GitHub
parent 990535b1cf
commit 9752ffe979
2 changed files with 9 additions and 16 deletions


@@ -258,16 +258,14 @@ def chatglm2_quantized_attention_forward_8eb45c(
             query_split = torch.split(query_layer, block_size, dim=1)
             key_split = torch.split(key, block_size, dim=1)
             value_split = torch.split(value, block_size, dim=1)
-            context_layer = torch.empty(batch_size, n_head, seq_len,
-                                        head_dim, dtype=key.dtype).to(query_layer.device)
-            idx = 0
+            results = []
             for q, k, v in zip(query_split, key_split, value_split):
                 if attention_mask is None:
                     result = F.scaled_dot_product_attention(q, k, v, is_causal=True)
                 else:
                     result = F.scaled_dot_product_attention(q, k, v, attention_mask)
-                context_layer[:, idx:idx+q.shape[1], :, :] = result
-                idx = idx + q.shape[1]
+                results.append(result)
+            context_layer = torch.cat(results, dim=1)
         else:
             if attention_mask is None:
                 context_layer = F.scaled_dot_product_attention(query_layer, key,
@@ -541,14 +539,11 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
             query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
             key_split = torch.split(key_layer, block_size, dim=1)
             value_split = torch.split(value_layer, block_size, dim=1)
-            batch_size, n_head, seq_len, head_dim = query_layer.shape
-            context_layer = torch.empty(batch_size, n_head, seq_len,
-                                        head_dim, dtype=key_layer.dtype).to(query_layer.device)
-            idx = 0
+            results = []
             for q, k, v in zip(query_split, key_split, value_split):
                 result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
-                context_layer[:, idx:idx+q.shape[1], :, :] = result
-                idx = idx + q.shape[1]
+                results.append(result)
+            context_layer = torch.cat(results, dim=1)
         else:
             context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                            key_layer,
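
For reference, a minimal, self-contained sketch of the pattern both hunks above move to: split query/key/value into head blocks, run F.scaled_dot_product_attention per block, and stitch the per-block outputs back together with torch.cat instead of writing slices into a preallocated torch.empty buffer. The shapes and block size below are illustrative, not taken from the ipex-llm sources.

    # Sketch: block-wise SDPA over head chunks, reassembled with torch.cat
    # (shapes and block_size are illustrative).
    import torch
    import torch.nn.functional as F

    batch_size, n_head, seq_len, head_dim = 2, 32, 128, 64
    block_size = 8  # heads per block

    query = torch.randn(batch_size, n_head, seq_len, head_dim)
    key = torch.randn(batch_size, n_head, seq_len, head_dim)
    value = torch.randn(batch_size, n_head, seq_len, head_dim)

    results = []
    for q, k, v in zip(query.split(block_size, dim=1),
                       key.split(block_size, dim=1),
                       value.split(block_size, dim=1)):
        # Heads attend independently, so each block can be computed on its own.
        results.append(F.scaled_dot_product_attention(q, k, v, is_causal=True))
    context_layer = torch.cat(results, dim=1)

    # The blocked result matches SDPA over the unsplit tensors (up to float error).
    reference = F.scaled_dot_product_attention(query, key, value, is_causal=True)
    assert torch.allclose(context_layer, reference, atol=1e-5)

Collecting blocks in a list and concatenating once avoids allocating the full output up front and leaves dtype and device handling to torch.cat rather than to a hand-built torch.empty buffer.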


@@ -1423,8 +1423,7 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
     query_split = torch.split(query.to(key.dtype), block_size, dim=1)
     key_split = torch.split(key.transpose(2, 3), block_size, dim=1)
     value_split = torch.split(value, block_size, dim=1)
-    attn_output = torch.empty(bsz, num_heads, q_len, head_dim).to(query.device)
-    idx = 0
+    attn_outputs = []
     for q, k, v in zip(query_split, key_split, value_split):
         attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
         block_actual_size = attn_weights_split.size(1)
@@ -1442,9 +1441,8 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask
                               f"but is {attention_mask.size()}")
             attn_weights_split = attn_weights_split + attention_mask
         attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
-        attn_weights_split = torch.matmul(attn_weights_split, v)
-        attn_output[:, idx:idx+block_actual_size, :, :] = attn_weights_split
-        idx = idx + block_actual_size
+        attn_outputs.append(torch.matmul(attn_weights_split, v))
+    attn_output = torch.cat(attn_outputs, dim=1)
     return attn_output.to(key.dtype), None
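
The same reassembly applies to the native (matmul + softmax) path that native_sdp_split_qkv_tensor uses; below is a hedged, self-contained sketch of that per-block computation, with an illustrative additive causal mask and shapes that are not taken from the repository.

    # Sketch: native SDP per head block, outputs collected and concatenated
    # (shapes, mask, and block_size are illustrative).
    import math

    import torch
    import torch.nn as nn

    bsz, num_heads, q_len, head_dim = 2, 32, 128, 64
    block_size = 8

    query = torch.randn(bsz, num_heads, q_len, head_dim)
    key = torch.randn(bsz, num_heads, q_len, head_dim)
    value = torch.randn(bsz, num_heads, q_len, head_dim)
    # Additive causal mask, broadcast over batch and heads.
    attention_mask = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal=1)

    attn_outputs = []
    for q, k, v in zip(query.split(block_size, dim=1),
                       key.transpose(2, 3).split(block_size, dim=1),
                       value.split(block_size, dim=1)):
        attn_weights = torch.matmul(q, k) / math.sqrt(head_dim)
        attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_outputs.append(torch.matmul(attn_weights, v))
    attn_output = torch.cat(attn_outputs, dim=1)  # (bsz, num_heads, q_len, head_dim)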