optimize siglip attention again (#12578)
parent e0921f80c1
commit 4540424271

3 changed files with 41 additions and 5 deletions
```diff
@@ -237,27 +237,50 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value:
         mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)
 
-        # compute
+        # import xe_addons
+        # if is_causal:
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
+        #     else:
+        #         attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
+        # elif seq_length != kv_length and seq_length <= 32:
+        #     # todo: add scale support
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask)
+        #     else:
+        #         attn_output = xe_addons.sdp(query, key, value, mask)
+        # else:
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
+        #     else:
+        #         attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
+
         import xe_addons
         if is_causal:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask)
             else:
-                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_causal(query, key, value, mask)
         elif seq_length != kv_length and seq_length <= 32:
             # todo: add scale support
             if key.dtype == torch.uint8:
                 attn_output = xe_addons.sdp_fp8(query, key, value, mask)
             else:
                 attn_output = xe_addons.sdp(query, key, value, mask)
         else:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_fp8(query, key, value, mask)
             else:
-                attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_non_causal(query, key, value, mask)
 
         return attn_output
     else:
         mask = mask[..., :seq_length, :kv_length] if mask is not None else None
+
+        from ipex_llm.transformers.models.utils import repeat_kv
+        if n_heads != n_kv_heads:
+            key = repeat_kv(key, n_heads // n_kv_heads)
+            value = repeat_kv(value, n_heads // n_kv_heads)
+
         return torch.nn.functional.scaled_dot_product_attention(
             query, key, value, mask, is_causal=is_causal, scale=scale
         )
```
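Taken together, this hunk does two things: the fused `xe_addons` path now calls the SDP kernels without the trailing `scale` argument (the scale-passing variants are kept commented out above the live code), and the plain-PyTorch fallback now expands grouped-query KV heads with `repeat_kv` before deferring to `torch.nn.functional.scaled_dot_product_attention`. The sketch below exercises only that fallback branch on CPU; it uses `torch.repeat_interleave`, which produces the same head layout as the `repeat_kv` helper added in the utils hunk further down, and all tensor sizes are made up for illustration.

```python
# Minimal CPU-side sketch of the fallback branch above. torch.repeat_interleave
# along the head dimension matches the layout produced by repeat_kv; sizes are
# illustrative only.
import torch

bsz, n_heads, n_kv_heads, seq_length, kv_length, head_dim = 1, 8, 2, 16, 16, 64
query = torch.randn(bsz, n_heads, seq_length, head_dim)
key = torch.randn(bsz, n_kv_heads, kv_length, head_dim)
value = torch.randn(bsz, n_kv_heads, kv_length, head_dim)
mask = None  # or a [bsz, 1, seq_length, kv_length] additive mask

# mirror the else branch: expand 2 KV heads to line up with the 8 query heads
if n_heads != n_kv_heads:
    key = torch.repeat_interleave(key, n_heads // n_kv_heads, dim=1)
    value = torch.repeat_interleave(value, n_heads // n_kv_heads, dim=1)

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, mask, is_causal=True, scale=None
)
print(attn_output.shape)  # torch.Size([1, 8, 16, 64])
```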
```diff
@@ -59,6 +59,10 @@ def siglip_attention_forward(
         and get_xpu_device_type(query_states) in ["arc", "flex"] and
         query_states.dtype in [torch.float, torch.half]
     ):
+        n_heads, kv_length = query_states.size(1), key_states.size(2)
+        from ipex_llm.transformers.models.common import prepare_mask
+        attention_mask = prepare_mask(attention_mask, bsz, n_heads, q_len, kv_length,
+                                      False, query_states.dtype, query_states.device)
         import xe_addons
         attn_weights = None
         attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,
```
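The addition here builds a properly shaped mask before handing it to the fused kernel: `n_heads` comes from `query_states.size(1)`, `kv_length` from `key_states.size(2)`, and `prepare_mask` (imported from `ipex_llm.transformers.models.common`) is called with the causal flag set to `False`, matching the non-causal SigLIP kernel. The body of `prepare_mask` is not part of this diff, so the snippet below uses a hypothetical stand-in, `prepare_mask_sketch`, purely to illustrate the shape, dtype, and device bookkeeping the call implies; the sizes are roughly SigLIP-like but illustrative.

```python
# prepare_mask_sketch is a hypothetical stand-in, NOT the real prepare_mask from
# ipex_llm.transformers.models.common (its body is not shown in this diff). It only
# illustrates the assumed bookkeeping: slice the mask to the current window, move it
# to the query's dtype/device, and broadcast it to one mask per attention head.
import torch


def prepare_mask_sketch(mask, bsz, n_heads, q_len, kv_len, is_causal, dtype, device):
    if mask is None:
        return None
    # is_causal is False for SigLIP, so no causal masking is applied in this sketch
    mask = mask[..., :q_len, :kv_len].to(dtype=dtype, device=device)
    return mask.expand(bsz, n_heads, q_len, kv_len)


bsz, n_heads, q_len, head_dim = 1, 16, 1024, 72   # roughly SigLIP-sized, illustrative
query_states = torch.randn(bsz, n_heads, q_len, head_dim, dtype=torch.half)
key_states = torch.randn(bsz, n_heads, q_len, head_dim, dtype=torch.half)
attention_mask = torch.zeros(bsz, 1, q_len, q_len)

kv_length = key_states.size(2)
attention_mask = prepare_mask_sketch(attention_mask, bsz, query_states.size(1), q_len,
                                     kv_length, False, query_states.dtype,
                                     query_states.device)
print(attention_mask.shape, attention_mask.dtype)
# torch.Size([1, 16, 1024, 1024]) torch.float16
```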
```diff
@@ -388,6 +388,15 @@ def fp16_fusion_check(proj, x, training):
     return True
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
+                                                           n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 def update_past_key_value(past_key_value, key_states, value_states,
                           kv_seq_len, use_quantize_kv, device):
     bsz, num_heads, _, head_dim = key_states.shape
```
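The new `repeat_kv` is the standard grouped-query-attention head-expansion helper (the same shape manipulation used across Hugging Face transformers model code), and it is what the fallback branch in the first hunk imports. Each KV head is duplicated `n_rep` times in place, so key and value line up with the query heads. A quick standalone check, with the function copied verbatim from the hunk above:

```python
# Each KV head is duplicated n_rep times consecutively, so output head i is a
# copy of input head i // n_rep; the result matches torch.repeat_interleave.
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
                                                           n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


kv = torch.randn(1, 2, 5, 4)               # [batch, n_kv_heads, seq_len, head_dim]
out = repeat_kv(kv, 4)
print(out.shape)                           # torch.Size([1, 8, 5, 4])
assert torch.equal(out[:, 3], kv[:, 0])    # heads 0-3 are copies of KV head 0
assert torch.equal(out[:, 4], kv[:, 1])    # heads 4-7 are copies of KV head 1
assert torch.equal(out, torch.repeat_interleave(kv, 4, dim=1))
```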