padding mask on torch side (#12577)

This commit is contained in:
Yishuo Wang 2024-12-19 10:53:02 +08:00 committed by GitHub
parent 47e90a362f
commit e0921f80c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -184,3 +184,80 @@ def layer_norm_forward(self, hidden_states: torch.Tensor):
hidden_states, self.normalized_shape,
self.weight, self.bias, self.eps
)
def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device):
max_kvs = 128
padding_kv_length = (kv_length + max_kvs - 1) // max_kvs * max_kvs
if mask is None:
if is_causal:
mask = torch.full([1, 1, seq_length, padding_kv_length], torch.finfo(dtype).min,
dtype=dtype, device=device)
mask.triu_(1)
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
elif seq_length != kv_length and seq_length <= 32:
mask = None
else:
mask = torch.zeros([1, 1, 1, padding_kv_length], torch.finfo(dtype).min,
dtype=dtype, device=device)
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
else:
if seq_length != kv_length and seq_length <= 32:
mask = mask[..., :seq_length, :kv_length]
mask = mask.expand([bsz, n_heads, seq_length, kv_length])
elif mask.size(3) != padding_kv_length:
new_mask = torch.empty([bsz, 1, seq_length, padding_kv_length],
dtype=dtype, device=device)
new_mask[:, :, :, :kv_length] = mask[:, 0:1, :seq_length, :kv_length]
new_mask[:, :, :, kv_length:] = torch.finfo(dtype).min
new_mask = new_mask.expand([bsz, n_heads, seq_length, padding_kv_length])
mask.set_(new_mask) # modify `mask` inplaced
else:
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
return mask
def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
mask: torch.Tensor = None,
is_causal: bool = False, scale: float = None) -> torch.Tensor:
bsz, n_heads, seq_length, head_dim = query.shape
_, n_kv_heads, kv_length, _ = key.shape
dtype, device = query.dtype, query.device
if (
device.type == "xpu"
and dtype in [torch.float, torch.half]
and head_dim in [64, 80, 96, 128]
):
# prepare scale
scale = 1 / math.sqrt(head_dim) if scale is None else scale
# prepare mask
mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)
# compute
import xe_addons
if is_causal:
if key.dtype == torch.uint8:
attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
else:
attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
elif seq_length != kv_length and seq_length <= 32:
# todo: add scale support
if key.dtype == torch.uint8:
attn_output = xe_addons.sdp_fp8(query, key, value, mask)
else:
attn_output = xe_addons.sdp(query, key, value, mask)
else:
if key.dtype == torch.uint8:
attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
else:
attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
return attn_output
else:
mask = mask[..., :seq_length, :kv_length] if mask is not None else None
return torch.nn.functional.scaled_dot_product_attention(
query, key, value, mask, is_causal=is_causal, scale=scale
)