From 7c43ac0164fa2ed0acc2db00c4aaa361f1179edd Mon Sep 17 00:00:00 2001
From: Cengguang Zhang
Date: Mon, 8 Apr 2024 17:48:11 +0800
Subject: [PATCH] LLM: optimize llama native sdp for split qkv tensor (#10693)

* LLM: optimize llama native sdp for split qkv tensor.

* fix actual block size.

* fix comment.

* fix style.

* refactor.
---
 .../src/ipex_llm/transformers/models/llama.py | 21 +++++++++++--------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 56ac4e8a..ee367131 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1331,7 +1331,7 @@ def native_sdp(query, key, value, attention_mask,
                 bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
     if should_split_qkv_tensor(query, output_attentions):
         return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
-                                           bsz, q_len, kv_seq_len, head_dim)
+                                           bsz, q_len, kv_seq_len, head_dim, num_heads)
     else:
         attn_weights = torch.matmul(query.to(key.dtype),
                                     key.transpose(2, 3)) / math.sqrt(head_dim)
@@ -1362,14 +1362,17 @@ def native_sdp(query, key, value, attention_mask,
 
 
 def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
-                                bsz, q_len, kv_seq_len, head_dim):
-    query_split = torch.split(query.to(key.dtype), 16, dim=1)
-    key_split = torch.split(key.transpose(2, 3), 16, dim=1)
-    value_split = torch.split(value, 16, dim=1)
-    attn_outputs = []
+                                bsz, q_len, kv_seq_len, head_dim, num_heads):
+    block_size = 8
+    query_split = torch.split(query.to(key.dtype), block_size, dim=1)
+    key_split = torch.split(key.transpose(2, 3), block_size, dim=1)
+    value_split = torch.split(value, block_size, dim=1)
+    attn_output = torch.empty(bsz, num_heads, q_len, head_dim).to(query.device)
+    idx = 0
     for q, k, v in zip(query_split, key_split, value_split):
         attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
-        attn_weights_split_size = (bsz, 16, q_len, kv_seq_len)
+        block_actual_size = attn_weights_split.size(1)
+        attn_weights_split_size = (bsz, block_actual_size, q_len, kv_seq_len)
         if attn_weights_split.size() != attn_weights_split_size:
             invalidInputError(False,
                               f"Splitted attention weights should be of size "
@@ -1384,8 +1387,8 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
             attn_weights_split = attn_weights_split + attention_mask
         attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
         attn_weights_split = torch.matmul(attn_weights_split, v)
-        attn_outputs.append(attn_weights_split)
-    attn_output = torch.cat(attn_outputs, dim=1)
+        attn_output[:, idx:idx+block_actual_size, :, :] = attn_weights_split
+        idx = idx + block_actual_size
     return attn_output, None
 
 
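
For reference, below is a minimal, self-contained sketch of the head-blockwise SDP that this patch implements. The function name sdp_split_heads, the toy shapes in the usage lines, and the explicit dtype handling are illustrative assumptions, not part of the patch; the repo-specific helpers should_split_qkv_tensor and invalidInputError, the shape checks, and the (attn_output, None) return convention of native_sdp_split_qkv_tensor are deliberately left out.

import math

import torch
import torch.nn as nn


def sdp_split_heads(query, key, value, attention_mask=None, block_size=8):
    """Scaled dot-product attention computed over blocks of attention heads.

    query:          (bsz, num_heads, q_len, head_dim)
    key, value:     (bsz, num_heads, kv_seq_len, head_dim)
    attention_mask: broadcastable to (bsz, 1, q_len, kv_seq_len), or None
    """
    bsz, num_heads, q_len, head_dim = query.shape

    # Split along the head dimension so that only block_size heads' worth of
    # (q_len x kv_seq_len) attention weights are materialized at a time.
    q_blocks = torch.split(query.to(key.dtype), block_size, dim=1)
    k_blocks = torch.split(key.transpose(2, 3), block_size, dim=1)
    v_blocks = torch.split(value, block_size, dim=1)

    # Preallocate the full output and write each head block in place instead
    # of collecting per-block results and concatenating them at the end.
    attn_output = torch.empty(bsz, num_heads, q_len, head_dim,
                              dtype=value.dtype, device=query.device)
    idx = 0
    for q, k, v in zip(q_blocks, k_blocks, v_blocks):
        block_heads = q.size(1)  # the last block may contain fewer heads
        attn_weights = torch.matmul(q, k) / math.sqrt(head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_output[:, idx:idx + block_heads] = torch.matmul(attn_weights, v)
        idx += block_heads
    return attn_output


# Usage with toy shapes (illustrative only).
q = torch.randn(1, 32, 16, 128)
k = torch.randn(1, 32, 48, 128)
v = torch.randn(1, 32, 48, 128)
print(sdp_split_heads(q, k, v).shape)  # torch.Size([1, 32, 16, 128])

Preallocating attn_output and writing each head block in place mirrors the patch's move away from appending per-block outputs to a list and calling torch.cat, which avoids holding both all block outputs and the concatenated copy in memory at the same time.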