support passing attn_scale to sdpa (#12619)
parent 40a7d2b4f0
commit a9abde0b5d
1 changed file with 5 additions and 22 deletions
@@ -237,40 +237,23 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
         mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)
 
         # compute
-        # import xe_addons
-        # if is_causal:
-        #     if key.dtype == torch.uint8:
-        #         attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
-        #     else:
-        #         attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
-        # elif seq_length != kv_length and seq_length <= 32:
-        #     # todo: add scale support
-        #     if key.dtype == torch.uint8:
-        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask)
-        #     else:
-        #         attn_output = xe_addons.sdp(query, key, value, mask)
-        # else:
-        #     if key.dtype == torch.uint8:
-        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
-        #     else:
-        #         attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
-
         import xe_addons
         if is_causal:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask)
+                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
             else:
-                attn_output = xe_addons.sdp_causal(query, key, value, mask)
+                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
         elif seq_length != kv_length and seq_length <= 32:
             # todo: add scale support
             if key.dtype == torch.uint8:
                 attn_output = xe_addons.sdp_fp8(query, key, value, mask)
             else:
                 attn_output = xe_addons.sdp(query, key, value, mask)
         else:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask)
+                attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale)
             else:
-                attn_output = xe_addons.sdp_non_causal(query, key, value, mask)
+                attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
 
         return attn_output
     else:
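
For context, below is a minimal sketch of what passing an explicit attention scale means, written against PyTorch's reference torch.nn.functional.scaled_dot_product_attention (whose scale keyword exists in recent PyTorch releases) rather than the xe_addons kernels touched by this diff; the tensor shapes and the attn_scale value are illustrative assumptions, not taken from this commit.

# Illustrative sketch only: uses PyTorch's reference SDPA, not the xe_addons
# kernels from the hunk above. Shapes and attn_scale are arbitrary example values.
import math
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 1, 8, 16, 64
query = torch.randn(bsz, n_heads, seq_len, head_dim)
key = torch.randn(bsz, n_heads, seq_len, head_dim)
value = torch.randn(bsz, n_heads, seq_len, head_dim)

# With scale left unset, SDPA applies the default 1 / sqrt(head_dim).
default_out = F.scaled_dot_product_attention(query, key, value, is_causal=True)

# Passing an explicit scale overrides that default; this is the value the commit
# now forwards to the causal and non-causal xe_addons kernels.
attn_scale = 1.0 / math.sqrt(head_dim)  # example value, equal to the default here
custom_out = F.scaled_dot_product_attention(query, key, value, is_causal=True,
                                            scale=attn_scale)

# Since the example passes the same value as the default, the outputs should match.
assert torch.allclose(default_out, custom_out, atol=1e-6)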