optimize siglip attention again (#12578)
parent e0921f80c1
commit 4540424271

3 changed files with 41 additions and 5 deletions
@@ -237,27 +237,50 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value:
         mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)
+
         # compute
+        # import xe_addons
+        # if is_causal:
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
+        #     else:
+        #         attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
+        # elif seq_length != kv_length and seq_length <= 32:
+        #     # todo: add scale support
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask)
+        #     else:
+        #         attn_output = xe_addons.sdp(query, key, value, mask)
+        # else:
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
+        #     else:
+        #         attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
+
         import xe_addons
         if is_causal:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask)
             else:
-                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_causal(query, key, value, mask)
         elif seq_length != kv_length and seq_length <= 32:
-            # todo: add scale support
             if key.dtype == torch.uint8:
                 attn_output = xe_addons.sdp_fp8(query, key, value, mask)
             else:
                 attn_output = xe_addons.sdp(query, key, value, mask)
         else:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_fp8(query, key, value, mask)
             else:
-                attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_non_causal(query, key, value, mask)
 
         return attn_output
     else:
         mask = mask[..., :seq_length, :kv_length] if mask is not None else None
 
+        from ipex_llm.transformers.models.utils import repeat_kv
+        if n_heads != n_kv_heads:
+            key = repeat_kv(key, n_heads // n_kv_heads)
+            value = repeat_kv(value, n_heads // n_kv_heads)
+
         return torch.nn.functional.scaled_dot_product_attention(
             query, key, value, mask, is_causal=is_causal, scale=scale
         )
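For reference, a minimal sketch of what the non-XPU fallback branch above computes: grouped-query key/value heads are expanded with repeat_kv and the call lands on torch.nn.functional.scaled_dot_product_attention. The head counts and sequence lengths below are illustrative, and the repeat_kv import assumes an ipex-llm build that already contains this commit:

# Illustrative sketch of the fallback path (sizes are made up for the example).
import torch
import torch.nn.functional as F

from ipex_llm.transformers.models.utils import repeat_kv  # helper added by this commit

bsz, n_heads, n_kv_heads, seq_length, head_dim = 1, 8, 2, 16, 64
query = torch.randn(bsz, n_heads, seq_length, head_dim)
key = torch.randn(bsz, n_kv_heads, seq_length, head_dim)
value = torch.randn(bsz, n_kv_heads, seq_length, head_dim)

# Expand KV heads so they match the query heads, as in the else-branch above.
if n_heads != n_kv_heads:
    key = repeat_kv(key, n_heads // n_kv_heads)
    value = repeat_kv(value, n_heads // n_kv_heads)

# scale=None lets PyTorch apply the default 1/sqrt(head_dim);
# the explicit `scale` keyword needs PyTorch 2.1 or newer.
attn_output = F.scaled_dot_product_attention(query, key, value, None,
                                             is_causal=False, scale=None)
print(attn_output.shape)  # torch.Size([1, 8, 16, 64])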
@@ -59,6 +59,10 @@ def siglip_attention_forward(
         and get_xpu_device_type(query_states) in ["arc", "flex"] and
         query_states.dtype in [torch.float, torch.half]
     ):
+        n_heads, kv_length = query_states.size(1), key_states.size(2)
+        from ipex_llm.transformers.models.common import prepare_mask
+        attention_mask = prepare_mask(attention_mask, bsz, n_heads, q_len, kv_length,
+                                      False, query_states.dtype, query_states.device)
         import xe_addons
         attn_weights = None
         attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,
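As a rough point of reference, the sketch below reproduces what this branch is expected to compute with plain PyTorch, under two assumptions the commit itself does not spell out: that xe_addons.siglip_sdp_non_causal performs standard non-causal softmax attention over [bsz, n_heads, seq, head_dim] tensors, and that prepare_mask returns an additive mask broadcast to [bsz, n_heads, q_len, kv_length]. Shapes are illustrative:

# Hedged pure-PyTorch stand-in for the SigLIP branch above (assumed semantics,
# illustrative shapes; not the xe_addons kernel itself).
import torch
import torch.nn.functional as F

bsz, n_heads, q_len, head_dim = 1, 12, 64, 64
query_states = torch.randn(bsz, n_heads, q_len, head_dim)
key_states = torch.randn(bsz, n_heads, q_len, head_dim)
value_states = torch.randn(bsz, n_heads, q_len, head_dim)
kv_length = key_states.size(2)

# Stand-in for prepare_mask(...): an all-zero additive mask already broadcast
# to [bsz, n_heads, q_len, kv_length] in the query dtype/device.
attention_mask = torch.zeros(bsz, n_heads, q_len, kv_length,
                             dtype=query_states.dtype, device=query_states.device)

attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
                                             attn_mask=attention_mask, is_causal=False)
print(attn_output.shape)  # torch.Size([1, 12, 64, 64])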
@@ -388,6 +388,15 @@ def fp16_fusion_check(proj, x, training):
     return True
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
+                                                           n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 def update_past_key_value(past_key_value, key_states, value_states,
                           kv_seq_len, use_quantize_kv, device):
     bsz, num_heads, _, head_dim = key_states.shape
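A short usage sketch of the new repeat_kv helper: it repeats each KV head n_rep consecutive times along the head dimension so grouped-query key/value tensors line up with the query heads. The sizes below are illustrative, and the import assumes an ipex-llm build that contains this commit:

# Shape check for repeat_kv (illustrative sizes).
import torch

from ipex_llm.transformers.models.utils import repeat_kv  # added above

key_states = torch.randn(2, 4, 128, 64)   # [bsz, n_kv_heads, seq_len, head_dim]
n_heads = 32
n_rep = n_heads // key_states.size(1)     # 32 query heads / 4 KV heads = 8

expanded = repeat_kv(key_states, n_rep)
print(expanded.shape)                     # torch.Size([2, 32, 128, 64])

# Each KV head is repeated n_rep consecutive times along dim 1:
# expanded heads 0..7 all equal KV head 0, heads 8..15 equal KV head 1, and so on.
assert torch.equal(expanded[:, 3], key_states[:, 0])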