LLM: support llama2 8k input with w4a16. (#10677)
* LLM: support llama2 8k input with w4a16. * fix comment and style. * fix style. * fix comments and split tensor to quantized attention forward. * fix style. * refactor name. * fix style. * fix style. * fix style. * refactor checker name. * refactor native sdp split qkv tensor name. * fix style. * fix comment rename variables. * fix co-exist of intermedia results.
This commit is contained in:
parent
db7c5cb78f
commit
c0cd238e40
1 changed files with 76 additions and 34 deletions
|
|
@ -214,6 +214,15 @@ def should_use_fast_rope(self, query_states, position_ids):
|
|||
return use_fuse_rope
|
||||
|
||||
|
||||
def should_split_qkv_tensor(query_states, output_attentions):
|
||||
if not output_attentions and query_states.dtype == torch.float16 and \
|
||||
query_states.shape[2] >= 6800:
|
||||
# split tensor for memory block limitation
|
||||
# support fp16 and set input length threshold at 6800 for now
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def llama_decoder_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
|
@ -404,7 +413,7 @@ def llama_attention_forward_4_31_quantized(
|
|||
attn_output, attn_weights = native_sdp(query_states, repeated_key_states,
|
||||
repeated_value_states, attention_mask,
|
||||
bsz, q_len, kv_seq_len,
|
||||
self.head_dim, self.num_heads)
|
||||
self.head_dim, self.num_heads, output_attentions)
|
||||
if use_cache:
|
||||
k_cache, v_cache = init_fp8_kv_cache(
|
||||
bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
|
||||
|
|
@ -429,7 +438,7 @@ def llama_attention_forward_4_31_quantized(
|
|||
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
|
||||
attention_mask,
|
||||
bsz, q_len, kv_seq_len,
|
||||
self.head_dim, self.num_heads)
|
||||
self.head_dim, self.num_heads, output_attentions)
|
||||
else:
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
|
||||
|
|
@ -642,8 +651,7 @@ def llama_attention_forward_4_31_original(
|
|||
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
|
||||
attention_mask,
|
||||
bsz, q_len, kv_seq_len,
|
||||
self.head_dim, self.num_heads)
|
||||
|
||||
self.head_dim, self.num_heads, output_attentions)
|
||||
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
|
||||
if attn_output.size() != attn_output_size:
|
||||
invalidInputError(False,
|
||||
|
|
@ -814,7 +822,8 @@ def llama_attention_selective_batching_forward_4_31(
|
|||
1,
|
||||
current_kv_len,
|
||||
self.head_dim,
|
||||
self.num_heads)
|
||||
self.num_heads,
|
||||
output_attentions)
|
||||
if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
|
||||
invalidInputError(False,
|
||||
f"`attn_output` should be of size "
|
||||
|
|
@ -858,7 +867,8 @@ def llama_attention_selective_batching_forward_4_31(
|
|||
q_len,
|
||||
kv_seq_len,
|
||||
self.head_dim,
|
||||
self.num_heads)
|
||||
self.num_heads,
|
||||
output_attentions)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
invalidInputError(False,
|
||||
|
|
@ -1291,7 +1301,7 @@ def llama_attention_forward_4_36_original(
|
|||
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
|
||||
attention_mask,
|
||||
bsz, q_len, kv_seq_len,
|
||||
self.head_dim, self.num_heads)
|
||||
self.head_dim, self.num_heads, output_attentions)
|
||||
|
||||
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
|
||||
if attn_output.size() != attn_output_size:
|
||||
|
|
@ -1318,7 +1328,11 @@ def llama_attention_forward_4_36_original(
|
|||
|
||||
|
||||
def native_sdp(query, key, value, attention_mask,
|
||||
bsz, q_len, kv_seq_len, head_dim, num_heads):
|
||||
bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
|
||||
if should_split_qkv_tensor(query, output_attentions):
|
||||
return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
|
||||
bsz, q_len, kv_seq_len, head_dim)
|
||||
else:
|
||||
attn_weights = torch.matmul(query.to(key.dtype),
|
||||
key.transpose(2, 3)) / math.sqrt(head_dim)
|
||||
|
||||
|
|
@ -1347,6 +1361,34 @@ def native_sdp(query, key, value, attention_mask,
|
|||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
|
||||
bsz, q_len, kv_seq_len, head_dim):
|
||||
query_split = torch.split(query.to(key.dtype), 16, dim=1)
|
||||
key_split = torch.split(key.transpose(2, 3), 16, dim=1)
|
||||
value_split = torch.split(value, 16, dim=1)
|
||||
attn_outputs = []
|
||||
for q, k, v in zip(query_split, key_split, value_split):
|
||||
attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
|
||||
attn_weights_split_size = (bsz, 16, q_len, kv_seq_len)
|
||||
if attn_weights_split.size() != attn_weights_split_size:
|
||||
invalidInputError(False,
|
||||
f"Splitted attention weights should be of size "
|
||||
f"{attn_weights_split_size}, but is {attn_weights_split.size()}")
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
|
||||
if attention_mask.size() != attn_mask_size:
|
||||
invalidInputError(False,
|
||||
f"Attention mask should be of size {attn_mask_size}, "
|
||||
f"but is {attention_mask.size()}")
|
||||
attn_weights_split = attn_weights_split + attention_mask
|
||||
attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
|
||||
attn_weights_split = torch.matmul(attn_weights_split, v)
|
||||
attn_outputs.append(attn_weights_split)
|
||||
attn_output = torch.cat(attn_outputs, dim=1)
|
||||
return attn_output, None
|
||||
|
||||
|
||||
def llama_model_selective_batching_forward_4_31(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
|
|
@ -1601,7 +1643,7 @@ def llama_attention_fast_forward(
|
|||
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
|
||||
attention_mask,
|
||||
bsz, q_len, kv_seq_len,
|
||||
self.head_dim, self.num_heads)
|
||||
self.head_dim, self.num_heads, output_attentions)
|
||||
|
||||
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
|
||||
if attn_output.size() != attn_output_size:
|
||||
|
|
|
|||
Loading…
Reference in a new issue