optimize siglip attention on arc (#12569)
parent 1a2ab12876
commit a4eb561f36

1 changed file with 27 additions and 16 deletions
@@ -53,28 +53,39 @@ def siglip_attention_forward(
     qkv = qkv.transpose(1, 2)
     query_states, key_states, value_states = qkv.chunk(3, dim=1)
 
-    query_states, key_states, value_states = padding_qkv_hd(
-        query_states, key_states, value_states,
-        72, 80
-    )
-
-    if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
+    from ipex_llm.transformers.utils import get_xpu_device_type
+    if (
+        self.head_dim == 72
+        and get_xpu_device_type(query_states) in ["arc", "flex"] and
+        query_states.dtype in [torch.float, torch.half]
+    ):
         import xe_addons
         attn_weights = None
-        attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
-                                               value_states.contiguous(), attention_mask)
+        attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,
+                                                      value_states, attention_mask)
     else:
-        attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+        query_states, key_states, value_states = padding_qkv_hd(
+            query_states, key_states, value_states,
+            72, 80
+        )
 
-        attn_weights = attention_softmax(attn_weights)
+        if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
+            import xe_addons
+            attn_weights = None
+            attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
+                                                   value_states.contiguous(), attention_mask)
+        else:
+            attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
+            if attention_mask is not None:
+                attn_weights = attn_weights + attention_mask
 
-        attn_weights = torch.nn.functional.dropout(attn_weights,
-                                                   p=self.dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+            attn_weights = attention_softmax(attn_weights)
 
-    attn_output = attn_output[:, :, :, :self.head_dim]
+            attn_weights = torch.nn.functional.dropout(attn_weights,
+                                                       p=self.dropout, training=self.training)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+        attn_output = attn_output[:, :, :, :self.head_dim]
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
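
The change adds a fast path for SigLIP attention on Intel Arc and Flex GPUs: when head_dim == 72 and the inputs are fp32 or fp16, it calls the dedicated xe_addons.siglip_sdp_non_causal kernel directly, instead of first padding q/k/v from head dim 72 to 80 with padding_qkv_hd and slicing the padding off the output again. The old padded route (generic sdp_non_causal, or the explicit matmul/softmax/dropout fallback) stays in the else branch for other devices and dtypes. The snippet below is a minimal, self-contained sketch of that padding trick, not ipex_llm code: torch.nn.functional.scaled_dot_product_attention stands in for the xe_addons kernels (its scale keyword needs PyTorch 2.1+), and padded_sdp_non_causal and the tensor shapes are illustrative only.

import torch
import torch.nn.functional as F


def padded_sdp_non_causal(q, k, v, attn_mask=None, orig_hd=72, padded_hd=80):
    # q/k/v: [bsz, num_heads, seq_len, orig_hd]; pad the head dim so a kernel
    # that only supports head_dim 80 can be used, then slice the result back.
    pad = padded_hd - orig_hd
    q = F.pad(q, (0, pad))  # zero-padding q and k leaves q @ k^T unchanged
    k = F.pad(k, (0, pad))
    v = F.pad(v, (0, pad))  # zero-padding v only adds all-zero output channels
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask,
                                         scale=orig_hd ** -0.5)
    return out[..., :orig_hd]  # mirrors attn_output[:, :, :, :self.head_dim]


if __name__ == "__main__":
    bsz, heads, seq, hd = 1, 16, 1024, 72
    q, k, v = (torch.randn(bsz, heads, seq, hd) for _ in range(3))
    direct = F.scaled_dot_product_attention(q, k, v)  # head_dim 72 handled directly
    padded = padded_sdp_non_causal(q, k, v)           # pad to 80, compute, slice back
    print(torch.allclose(direct, padded, atol=1e-5))  # expected: True

On Arc/Flex the new branch skips both the pad and the slice, which is where the savings come from; the rest of the hunk just moves the previous path under the else branch unchanged.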