parent 6be70283b7
commit 16b2a418be

1 changed file with 9 additions and 4 deletions
@@ -855,11 +855,13 @@ def llama_attention_selective_batching_forward_4_31(
                 current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
                 current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)

+                cache_position = None
                 current_query_states = query_states[batch: batch + 1, :, :, :]
                 attn_output, attn_weights = native_sdp(current_query_states,
                                                        current_key_states,
                                                        current_value_states,
                                                        attention_mask[batch],
+                                                       cache_position,
                                                        1,
                                                        1,
                                                        current_kv_len,
@@ -901,10 +903,12 @@ def llama_attention_selective_batching_forward_4_31(
     if isinstance(attention_mask, list):
         # For decoding fast path
         attention_mask = attention_mask[0]
+    cache_position = None
     attn_output, attn_weights = native_sdp(query_states,
                                            key_states,
                                            value_states,
                                            attention_mask,
+                                           cache_position,
                                            bsz,
                                            q_len,
                                            kv_seq_len,
@@ -1445,7 +1449,7 @@ def llama_attention_forward_4_38_original(
 def native_sdp(query, key, value, attention_mask, cache_position,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
     if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
-        return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+        return native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
                                            bsz, q_len, kv_seq_len, head_dim, num_heads)
     else:
         attn_weights = torch.matmul(query.to(key.dtype),
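The one-line change above matters because native_sdp_split_qkv_tensor already takes cache_position as a positional parameter: if the call omits it, every later positional argument shifts by one and the call no longer binds. A minimal sketch of that failure mode, using a stub with the same parameter order rather than the library function itself (the body and values are illustrative only):

# Stub mirroring the parameter order of the real signature; not the library code.
def native_sdp_split_qkv_tensor_stub(query, key, value, attention_mask, cache_position,
                                     bsz, q_len, kv_seq_len, head_dim, num_heads):
    return bsz, q_len, kv_seq_len, head_dim, num_heads

try:
    # Old call shape: cache_position omitted, so bsz binds to cache_position,
    # q_len binds to bsz, and so on until the last parameter is left unfilled.
    native_sdp_split_qkv_tensor_stub("q", "k", "v", "mask",
                                     1, 1, 16, 128, 32)
except TypeError as err:
    print(err)  # ...missing 1 required positional argument: 'num_heads'

# New call shape: cache_position passed explicitly (None on code paths that
# still build a (bsz, 1, q_len, kv_seq_len) mask themselves).
print(native_sdp_split_qkv_tensor_stub("q", "k", "v", "mask", None,
                                       1, 1, 16, 128, 32))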
@@ -1502,14 +1506,14 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_positio
             if cache_position is not None:
                 # for transformers 4.38.0
                 causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
-                attn_weights = attn_weights + causal_mask
+                attn_weights_split = attn_weights_split + causal_mask
             else:
                 attn_mask_size = (bsz, 1, q_len, kv_seq_len)
                 if attention_mask.size() != attn_mask_size:
                     invalidInputError(False,
                                       f"Attention mask should be of size {attn_mask_size}, "
                                       f"but is {attention_mask.size()}")
-                attn_weights = attn_weights + attention_mask
+                attn_weights_split = attn_weights_split + attention_mask
         attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
         attn_outputs.append(torch.matmul(attn_weights_split, v))
     attn_output = torch.cat(attn_outputs, dim=1)
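A short, self-contained sketch of what the corrected branch does (shapes and values here are invented for illustration, not taken from the repository): under transformers 4.38 the causal mask spans the full KV length, cache_position selects the rows for the tokens currently being decoded, and the resulting (bsz, 1, q_len, kv_seq_len) slice broadcasts over the head dimension of each weight chunk, which is why it has to be added to attn_weights_split, the per-chunk tensor, rather than to attn_weights:

import torch

# Toy sizes: batch 2, 4 query tokens at absolute positions 8..11 of a
# 12-token KV cache, heads processed in chunks of 8.
bsz, q_len, kv_seq_len, chunk_heads = 2, 4, 12, 8

# 4.38-style dense causal mask over the full cache: 0 where attention is
# allowed, -inf above the diagonal.
full_mask = torch.zeros(bsz, 1, kv_seq_len, kv_seq_len)
full_mask.masked_fill_(torch.triu(torch.ones(kv_seq_len, kv_seq_len), diagonal=1).bool(),
                       float("-inf"))

# cache_position holds the positions of the tokens currently being decoded.
cache_position = torch.arange(kv_seq_len - q_len, kv_seq_len)

# The patched branch: pick the rows for the current positions, giving a
# (bsz, 1, q_len, kv_seq_len) slice.
causal_mask = full_mask[:, :, cache_position, :kv_seq_len]
print(causal_mask.shape)            # torch.Size([2, 1, 4, 12])

# Inside the split loop the weights are per-chunk, so the mask is added to
# attn_weights_split; the head dimension broadcasts.
attn_weights_split = torch.randn(bsz, chunk_heads, q_len, kv_seq_len)
attn_weights_split = attn_weights_split + causal_mask
print(attn_weights_split.shape)     # torch.Size([2, 8, 4, 12])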
@@ -1767,8 +1771,9 @@ def llama_attention_fast_forward(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)

+    cache_position = None
     attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                           attention_mask,
+                                           attention_mask, cache_position,
                                            bsz, q_len, kv_seq_len,
                                            self.head_dim, self.num_heads, output_attentions)

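As in the first two hunks, this caller passes cache_position as None; where a mask is supplied on these paths it is expected to already have the legacy (bsz, 1, q_len, kv_seq_len) shape, so native_sdp falls through to the size-check branch instead of the 4.38 slicing. A toy dispatcher (a stand-in with made-up names, not the library's native_sdp) showing the two branches side by side:

import torch

def add_causal_mask(attn_weights, attention_mask, cache_position,
                    bsz, q_len, kv_seq_len):
    # Stand-in for the masking branch sketched above: with cache_position the
    # full-length mask is sliced; with None the mask must already be
    # (bsz, 1, q_len, kv_seq_len).
    if cache_position is not None:
        causal_mask = attention_mask[:, :, cache_position, :kv_seq_len]
    else:
        assert attention_mask.size() == (bsz, 1, q_len, kv_seq_len)
        causal_mask = attention_mask
    return attn_weights + causal_mask

# Single-token decode with a pre-built legacy mask, as in the patched callers.
bsz, q_len, kv_seq_len, num_heads = 1, 1, 9, 4
attn_weights = torch.randn(bsz, num_heads, q_len, kv_seq_len)
legacy_mask = torch.zeros(bsz, 1, q_len, kv_seq_len)
out = add_causal_mask(attn_weights, legacy_mask, None, bsz, q_len, kv_seq_len)
print(out.shape)  # torch.Size([1, 4, 1, 9])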