LLM: support llama2 8k input with w4a16. (#10677)
* LLM: support llama2 8k input with w4a16.
* Fix comment and style.
* Fix style.
* Fix comments and split tensor in the quantized attention forward.
* Fix style.
* Refactor name.
* Fix style.
* Refactor checker name.
* Refactor native sdp split-qkv-tensor name.
* Fix style.
* Fix comment, rename variables.
* Fix co-existence of intermediate results.
This commit is contained in:
parent db7c5cb78f
commit c0cd238e40

1 changed file with 76 additions and 34 deletions
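Why split at long context: the intermediate attention-weights tensor produced by native_sdp has shape (bsz, num_heads, q_len, kv_seq_len), so it grows quadratically with sequence length. The lines below are not part of the patch; they are a back-of-the-envelope sketch, assuming llama2-7b-like shapes (32 heads, fp16), of the peak memory the full tensor would need at 8k input versus one 16-head chunk as used by the split path.

bsz, num_heads, q_len, kv_seq_len = 1, 32, 8192, 8192  # assumed llama2-7b-style shapes
bytes_per_fp16 = 2
full_attn_weights = bsz * num_heads * q_len * kv_seq_len * bytes_per_fp16
per_16_head_chunk = bsz * 16 * q_len * kv_seq_len * bytes_per_fp16
print(f"full attn_weights: {full_attn_weights / 2**30:.1f} GiB")   # 4.0 GiB
print(f"per 16-head chunk: {per_16_head_chunk / 2**30:.1f} GiB")   # 2.0 GiB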
@@ -214,6 +214,15 @@ def should_use_fast_rope(self, query_states, position_ids):
     return use_fuse_rope
 
 
+def should_split_qkv_tensor(query_states, output_attentions):
+    if not output_attentions and query_states.dtype == torch.float16 and \
+            query_states.shape[2] >= 6800:
+        # split tensor for memory block limitation
+        # support fp16 and set input length threshold at 6800 for now
+        return True
+    return False
+
+
 def llama_decoder_forward(
     self,
     hidden_states: torch.Tensor,
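For reference, and not part of the patch: a minimal usage sketch of the checker added above, assuming should_split_qkv_tensor and torch are in scope. A half-precision query of sequence length 8192 (above the 6800 threshold) takes the split path only when attention weights are not requested.

import torch

# (bsz, num_heads, q_len, head_dim) layout used by the attention code; values assumed
query_states = torch.empty(1, 32, 8192, 128, dtype=torch.float16)
print(should_split_qkv_tensor(query_states, output_attentions=False))  # True
print(should_split_qkv_tensor(query_states, output_attentions=True))   # False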
@@ -404,7 +413,7 @@ def llama_attention_forward_4_31_quantized(
         attn_output, attn_weights = native_sdp(query_states, repeated_key_states,
                                                repeated_value_states, attention_mask,
                                                bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads)
+                                               self.head_dim, self.num_heads, output_attentions)
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
@@ -429,7 +438,7 @@ def llama_attention_forward_4_31_quantized(
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                                attention_mask,
                                                bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads)
+                                               self.head_dim, self.num_heads, output_attentions)
     else:
         import linear_q4_0
         attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
@@ -642,8 +651,7 @@ def llama_attention_forward_4_31_original(
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                                attention_mask,
                                                bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads)
+                                               self.head_dim, self.num_heads, output_attentions)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
         invalidInputError(False,
@@ -814,7 +822,8 @@ def llama_attention_selective_batching_forward_4_31(
                                                            1,
                                                            current_kv_len,
                                                            self.head_dim,
-                                                           self.num_heads)
+                                                           self.num_heads,
+                                                           output_attentions)
                 if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
                     invalidInputError(False,
                                       f"`attn_output` should be of size "
@@ -858,7 +867,8 @@ def llama_attention_selective_batching_forward_4_31(
                                                    q_len,
                                                    kv_seq_len,
                                                    self.head_dim,
-                                                   self.num_heads)
+                                                   self.num_heads,
+                                                   output_attentions)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             invalidInputError(False,
@@ -1291,7 +1301,7 @@ def llama_attention_forward_4_36_original(
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                                attention_mask,
                                                bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads)
+                                               self.head_dim, self.num_heads, output_attentions)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
@@ -1318,33 +1328,65 @@ def llama_attention_forward_4_36_original(
 
 
 def native_sdp(query, key, value, attention_mask,
-               bsz, q_len, kv_seq_len, head_dim, num_heads):
-    attn_weights = torch.matmul(query.to(key.dtype),
-                                key.transpose(2, 3)) / math.sqrt(head_dim)
-
-    attn_weights_size = (bsz, num_heads, q_len, kv_seq_len)
-    if attn_weights.size() != attn_weights_size:
-        invalidInputError(False,
-                          f"Attention weights should be of size {attn_weights_size}, "
-                          f"but is {attn_weights.size()}")
-
-    if attention_mask is not None:
-        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-        if attention_mask.size() != attn_mask_size:
-            invalidInputError(False,
-                              f"Attention mask should be of size {attn_mask_size}, "
-                              f"but is {attention_mask.size()}")
-        attn_weights = attn_weights + attention_mask
-
-    if kv_seq_len >= 2048:
-        # for memory considerations, do not upcast attention to fp32 for long sequences
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-    else:
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                             dtype=torch.float32).to(value.dtype)
-    attn_output = torch.matmul(attn_weights, value)
-    return attn_output, attn_weights
+               bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
+    if should_split_qkv_tensor(query, output_attentions):
+        return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+                                           bsz, q_len, kv_seq_len, head_dim)
+    else:
+        attn_weights = torch.matmul(query.to(key.dtype),
+                                    key.transpose(2, 3)) / math.sqrt(head_dim)
+
+        attn_weights_size = (bsz, num_heads, q_len, kv_seq_len)
+        if attn_weights.size() != attn_weights_size:
+            invalidInputError(False,
+                              f"Attention weights should be of size {attn_weights_size}, "
+                              f"but is {attn_weights.size()}")
+
+        if attention_mask is not None:
+            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+            if attention_mask.size() != attn_mask_size:
+                invalidInputError(False,
+                                  f"Attention mask should be of size {attn_mask_size}, "
+                                  f"but is {attention_mask.size()}")
+            attn_weights = attn_weights + attention_mask
+
+        if kv_seq_len >= 2048:
+            # for memory considerations, do not upcast attention to fp32 for long sequences
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        else:
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                                 dtype=torch.float32).to(value.dtype)
+        attn_output = torch.matmul(attn_weights, value)
+        return attn_output, attn_weights
+
+
+def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+                                bsz, q_len, kv_seq_len, head_dim):
+    query_split = torch.split(query.to(key.dtype), 16, dim=1)
+    key_split = torch.split(key.transpose(2, 3), 16, dim=1)
+    value_split = torch.split(value, 16, dim=1)
+    attn_outputs = []
+    for q, k, v in zip(query_split, key_split, value_split):
+        attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
+        attn_weights_split_size = (bsz, 16, q_len, kv_seq_len)
+        if attn_weights_split.size() != attn_weights_split_size:
+            invalidInputError(False,
+                              f"Splitted attention weights should be of size "
+                              f"{attn_weights_split_size}, but is {attn_weights_split.size()}")
+
+        if attention_mask is not None:
+            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+            if attention_mask.size() != attn_mask_size:
+                invalidInputError(False,
+                                  f"Attention mask should be of size {attn_mask_size}, "
+                                  f"but is {attention_mask.size()}")
+            attn_weights_split = attn_weights_split + attention_mask
+        attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
+        attn_weights_split = torch.matmul(attn_weights_split, v)
+        attn_outputs.append(attn_weights_split)
+    attn_output = torch.cat(attn_outputs, dim=1)
+    return attn_output, None
 
 
 def llama_model_selective_batching_forward_4_31(
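The following standalone sketch is not part of the patch: it re-implements the head-chunked scaled-dot-product attention used by native_sdp_split_qkv_tensor above in isolation (chunk size of 16, no mask, toy shapes are all assumptions) and checks that it matches the unchunked computation, which is why the split path can return the concatenated chunks unchanged.

import math
import torch
import torch.nn as nn


def sdp_full(q, k, v):
    # reference: attention weights for all heads at once
    w = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.shape[-1])
    w = nn.functional.softmax(w, dim=-1)
    return torch.matmul(w, v)


def sdp_split_heads(q, k, v, chunk=16):
    # same math, computed `chunk` heads at a time to bound the attn-weights tensor
    outputs = []
    for qs, ks, vs in zip(q.split(chunk, dim=1),
                          k.transpose(2, 3).split(chunk, dim=1),
                          v.split(chunk, dim=1)):
        w = torch.matmul(qs, ks) / math.sqrt(q.shape[-1])
        w = nn.functional.softmax(w, dim=-1)
        outputs.append(torch.matmul(w, vs))
    return torch.cat(outputs, dim=1)  # re-join along the head dimension


q, k, v = (torch.randn(1, 32, 64, 128) for _ in range(3))
assert torch.allclose(sdp_full(q, k, v), sdp_split_heads(q, k, v), atol=1e-5)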
@@ -1601,7 +1643,7 @@ def llama_attention_fast_forward(
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                                attention_mask,
                                                bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads)
+                                               self.head_dim, self.num_heads, output_attentions)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size: