LLM: support llama2 8k input with w4a16. (#10677)
* LLM: support llama2 8k input with w4a16.
* fix comment and style.
* fix style.
* fix comments and split tensor to quantized attention forward.
* fix style.
* refactor name.
* fix style.
* fix style.
* fix style.
* refactor checker name.
* refactor native sdp split qkv tensor name.
* fix style.
* fix comment rename variables.
* fix co-exist of intermedia results.
parent db7c5cb78f
commit c0cd238e40

1 changed file with 76 additions and 34 deletions
@@ -214,6 +214,15 @@ def should_use_fast_rope(self, query_states, position_ids):
     return use_fuse_rope
 
 
+def should_split_qkv_tensor(query_states, output_attentions):
+    if not output_attentions and query_states.dtype == torch.float16 and \
+            query_states.shape[2] >= 6800:
+        # split tensor for memory block limitation
+        # support fp16 and set input length threshold at 6800 for now
+        return True
+    return False
+
+
 def llama_decoder_forward(
     self,
     hidden_states: torch.Tensor,
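For reference, a minimal standalone sketch of how this checker behaves (illustrative only, not part of the diff). It assumes the usual (batch, num_heads, seq_len, head_dim) layout for query_states, so shape[2] is the prompt length:

import torch

# Illustrative copy of the checker added above.
def should_split_qkv_tensor(query_states, output_attentions):
    if not output_attentions and query_states.dtype == torch.float16 and \
            query_states.shape[2] >= 6800:
        # split tensor for memory block limitation
        # support fp16 and set input length threshold at 6800 for now
        return True
    return False

# Dummy fp16 queries; only dtype and shape matter here.
short_q = torch.empty(1, 32, 1024, 128, dtype=torch.float16)
long_q = torch.empty(1, 32, 8192, 128, dtype=torch.float16)

print(should_split_qkv_tensor(short_q, output_attentions=False))  # False: 1024 < 6800
print(should_split_qkv_tensor(long_q, output_attentions=False))   # True: 8192 >= 6800
print(should_split_qkv_tensor(long_q, output_attentions=True))    # False: caller asked for attention weights

The split path is therefore taken only for long fp16 prompts when the caller does not need the attention weights back.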
@@ -404,7 +413,7 @@ def llama_attention_forward_4_31_quantized(
         attn_output, attn_weights = native_sdp(query_states, repeated_key_states,
                                                repeated_value_states, attention_mask,
                                                bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads)
+                                               self.head_dim, self.num_heads, output_attentions)
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
@@ -429,7 +438,7 @@ def llama_attention_forward_4_31_quantized(
             attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                                    attention_mask,
                                                    bsz, q_len, kv_seq_len,
-                                                   self.head_dim, self.num_heads)
+                                                   self.head_dim, self.num_heads, output_attentions)
         else:
             import linear_q4_0
             attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
@@ -642,8 +651,7 @@ def llama_attention_forward_4_31_original(
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                                attention_mask,
                                                bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads)
-
+                                               self.head_dim, self.num_heads, output_attentions)
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
         invalidInputError(False,
@@ -814,7 +822,8 @@ def llama_attention_selective_batching_forward_4_31(
                                                        1,
                                                        current_kv_len,
                                                        self.head_dim,
-                                                       self.num_heads)
+                                                       self.num_heads,
+                                                       output_attentions)
                 if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
                     invalidInputError(False,
                                       f"`attn_output` should be of size "
@@ -858,7 +867,8 @@ def llama_attention_selective_batching_forward_4_31(
                                            q_len,
                                            kv_seq_len,
                                            self.head_dim,
-                                           self.num_heads)
+                                           self.num_heads,
+                                           output_attentions)
 
     if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
         invalidInputError(False,
@@ -1291,7 +1301,7 @@ def llama_attention_forward_4_36_original(
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                                attention_mask,
                                                bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads)
+                                               self.head_dim, self.num_heads, output_attentions)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
@@ -1318,7 +1328,11 @@ def llama_attention_forward_4_36_original(
 
 
 def native_sdp(query, key, value, attention_mask,
-               bsz, q_len, kv_seq_len, head_dim, num_heads):
-    attn_weights = torch.matmul(query.to(key.dtype),
-                                key.transpose(2, 3)) / math.sqrt(head_dim)
+               bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
+    if should_split_qkv_tensor(query, output_attentions):
+        return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+                                           bsz, q_len, kv_seq_len, head_dim)
+    else:
+        attn_weights = torch.matmul(query.to(key.dtype),
+                                    key.transpose(2, 3)) / math.sqrt(head_dim)
 
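The inline comments attribute the split to a "memory block limitation" for fp16 inputs of 6800 tokens or more; the diff does not state the exact device limit. As a rough back-of-the-envelope illustration (the 32-head, 8k-token llama2-like shapes below are assumptions, not taken from the diff), the dense attn_weights buffer reaches 4 GiB, while each 16-head chunk needs half of that:

# Hypothetical llama2-7B-like shapes for an 8k prompt; fp16 = 2 bytes per element.
bsz, num_heads, q_len, kv_seq_len = 1, 32, 8192, 8192
bytes_fp16 = 2

full_weights = bsz * num_heads * q_len * kv_seq_len * bytes_fp16
chunk_weights = bsz * 16 * q_len * kv_seq_len * bytes_fp16  # one 16-head split

print(f"dense attn_weights buffer: {full_weights / 2**30:.1f} GiB")   # 4.0 GiB
print(f"per 16-head chunk:         {chunk_weights / 2**30:.1f} GiB")  # 2.0 GiB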
@@ -1347,6 +1361,34 @@ def native_sdp(query, key, value, attention_mask,
         return attn_output, attn_weights
 
 
+def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+                                bsz, q_len, kv_seq_len, head_dim):
+    query_split = torch.split(query.to(key.dtype), 16, dim=1)
+    key_split = torch.split(key.transpose(2, 3), 16, dim=1)
+    value_split = torch.split(value, 16, dim=1)
+    attn_outputs = []
+    for q, k, v in zip(query_split, key_split, value_split):
+        attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
+        attn_weights_split_size = (bsz, 16, q_len, kv_seq_len)
+        if attn_weights_split.size() != attn_weights_split_size:
+            invalidInputError(False,
+                              f"Splitted attention weights should be of size "
+                              f"{attn_weights_split_size}, but is {attn_weights_split.size()}")
+
+        if attention_mask is not None:
+            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+            if attention_mask.size() != attn_mask_size:
+                invalidInputError(False,
+                                  f"Attention mask should be of size {attn_mask_size}, "
+                                  f"but is {attention_mask.size()}")
+            attn_weights_split = attn_weights_split + attention_mask
+        attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
+        attn_weights_split = torch.matmul(attn_weights_split, v)
+        attn_outputs.append(attn_weights_split)
+    attn_output = torch.cat(attn_outputs, dim=1)
+    return attn_output, None
+
+
 def llama_model_selective_batching_forward_4_31(
     self,
     input_ids: torch.LongTensor = None,
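Splitting along dim=1 (the head dimension) leaves the math unchanged, because each head's softmax and weighted sum are independent; only the size of the intermediate attn_weights buffer shrinks, at the cost of returning None for the weights. A small self-contained sketch of that equivalence (the function name, toy shapes, and chunk size of 2 are illustrative, not taken from the diff):

import math
import torch
from torch import nn

# Hypothetical re-implementation of the head-chunked SDP above,
# with a configurable chunk size so it runs on small toy tensors.
def sdp_split_heads(query, key, value, attention_mask, head_dim, chunk=16):
    outputs = []
    for q, k, v in zip(torch.split(query, chunk, dim=1),
                       torch.split(key.transpose(2, 3), chunk, dim=1),
                       torch.split(value, chunk, dim=1)):
        w = torch.matmul(q, k) / math.sqrt(head_dim)
        if attention_mask is not None:
            w = w + attention_mask
        w = nn.functional.softmax(w, dim=-1)
        outputs.append(torch.matmul(w, v))
    return torch.cat(outputs, dim=1)

bsz, num_heads, seq_len, head_dim = 1, 8, 64, 16
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# Dense reference: same math without chunking.
ref = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
ref = torch.matmul(nn.functional.softmax(ref, dim=-1), v)

out = sdp_split_heads(q, k, v, attention_mask=None, head_dim=head_dim, chunk=2)
print(torch.allclose(out, ref, atol=1e-6))  # True: chunking over heads is exact

The version in the diff always splits in groups of 16 heads (two chunks for llama2-7B's 32 heads) and returns None for the attention weights, which is why should_split_qkv_tensor only routes here when output_attentions is False.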
@@ -1601,7 +1643,7 @@ def llama_attention_fast_forward(
     attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                            attention_mask,
                                            bsz, q_len, kv_seq_len,
-                                           self.head_dim, self.num_heads)
+                                           self.head_dim, self.num_heads, output_attentions)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size: