LLM: support llama2 8k input with w4a16. (#10677)

* LLM: support llama2 8k input with w4a16.

* fix comment and style.

* fix style.

* fix comments and split tensor to quantized attention forward.

* fix style.

* refactor name.

* fix style.

* fix style.

* fix style.

* refactor checker name.

* refactor native sdp split qkv tensor name.

* fix style.

* fix comment rename variables.

* fix co-exist of intermedia results.
This commit is contained in:
Cengguang Zhang 2024-04-08 11:43:15 +08:00 committed by GitHub
parent db7c5cb78f
commit c0cd238e40
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -214,6 +214,15 @@ def should_use_fast_rope(self, query_states, position_ids):
return use_fuse_rope
def should_split_qkv_tensor(query_states, output_attentions):
if not output_attentions and query_states.dtype == torch.float16 and \
query_states.shape[2] >= 6800:
# split tensor for memory block limitation
# support fp16 and set input length threshold at 6800 for now
return True
return False
def llama_decoder_forward(
self,
hidden_states: torch.Tensor,
@ -404,7 +413,7 @@ def llama_attention_forward_4_31_quantized(
attn_output, attn_weights = native_sdp(query_states, repeated_key_states,
repeated_value_states, attention_mask,
bsz, q_len, kv_seq_len,
self.head_dim, self.num_heads)
self.head_dim, self.num_heads, output_attentions)
if use_cache:
k_cache, v_cache = init_fp8_kv_cache(
bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
@ -429,7 +438,7 @@ def llama_attention_forward_4_31_quantized(
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
attention_mask,
bsz, q_len, kv_seq_len,
self.head_dim, self.num_heads)
self.head_dim, self.num_heads, output_attentions)
else:
import linear_q4_0
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
@ -642,8 +651,7 @@ def llama_attention_forward_4_31_original(
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
attention_mask,
bsz, q_len, kv_seq_len,
self.head_dim, self.num_heads)
self.head_dim, self.num_heads, output_attentions)
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
if attn_output.size() != attn_output_size:
invalidInputError(False,
@ -814,7 +822,8 @@ def llama_attention_selective_batching_forward_4_31(
1,
current_kv_len,
self.head_dim,
self.num_heads)
self.num_heads,
output_attentions)
if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
invalidInputError(False,
f"`attn_output` should be of size "
@ -858,7 +867,8 @@ def llama_attention_selective_batching_forward_4_31(
q_len,
kv_seq_len,
self.head_dim,
self.num_heads)
self.num_heads,
output_attentions)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(False,
@ -1291,7 +1301,7 @@ def llama_attention_forward_4_36_original(
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
attention_mask,
bsz, q_len, kv_seq_len,
self.head_dim, self.num_heads)
self.head_dim, self.num_heads, output_attentions)
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
if attn_output.size() != attn_output_size:
@ -1318,7 +1328,11 @@ def llama_attention_forward_4_36_original(
def native_sdp(query, key, value, attention_mask,
bsz, q_len, kv_seq_len, head_dim, num_heads):
bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
if should_split_qkv_tensor(query, output_attentions):
return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
bsz, q_len, kv_seq_len, head_dim)
else:
attn_weights = torch.matmul(query.to(key.dtype),
key.transpose(2, 3)) / math.sqrt(head_dim)
@ -1347,6 +1361,34 @@ def native_sdp(query, key, value, attention_mask,
return attn_output, attn_weights
def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
bsz, q_len, kv_seq_len, head_dim):
query_split = torch.split(query.to(key.dtype), 16, dim=1)
key_split = torch.split(key.transpose(2, 3), 16, dim=1)
value_split = torch.split(value, 16, dim=1)
attn_outputs = []
for q, k, v in zip(query_split, key_split, value_split):
attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
attn_weights_split_size = (bsz, 16, q_len, kv_seq_len)
if attn_weights_split.size() != attn_weights_split_size:
invalidInputError(False,
f"Splitted attention weights should be of size "
f"{attn_weights_split_size}, but is {attn_weights_split.size()}")
if attention_mask is not None:
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
if attention_mask.size() != attn_mask_size:
invalidInputError(False,
f"Attention mask should be of size {attn_mask_size}, "
f"but is {attention_mask.size()}")
attn_weights_split = attn_weights_split + attention_mask
attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
attn_weights_split = torch.matmul(attn_weights_split, v)
attn_outputs.append(attn_weights_split)
attn_output = torch.cat(attn_outputs, dim=1)
return attn_output, None
def llama_model_selective_batching_forward_4_31(
self,
input_ids: torch.LongTensor = None,
@ -1601,7 +1643,7 @@ def llama_attention_fast_forward(
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
attention_mask,
bsz, q_len, kv_seq_len,
self.head_dim, self.num_heads)
self.head_dim, self.num_heads, output_attentions)
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
if attn_output.size() != attn_output_size: