padding mask on torch side (#12577)
parent 47e90a362f
commit e0921f80c1
1 changed file with 77 additions and 0 deletions
@@ -184,3 +184,80 @@ def layer_norm_forward(self, hidden_states: torch.Tensor):
        hidden_states, self.normalized_shape,
        self.weight, self.bias, self.eps
    )


def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device):
    max_kvs = 128
    padding_kv_length = (kv_length + max_kvs - 1) // max_kvs * max_kvs
    if mask is None:
        if is_causal:
            mask = torch.full([1, 1, seq_length, padding_kv_length], torch.finfo(dtype).min,
                              dtype=dtype, device=device)
            mask.triu_(1)
            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
        elif seq_length != kv_length and seq_length <= 32:
            mask = None
        else:
            # zeros for the valid positions, dtype minimum for the padded tail
            mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
            mask[..., kv_length:] = torch.finfo(dtype).min
            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
    else:
        if seq_length != kv_length and seq_length <= 32:
            mask = mask[..., :seq_length, :kv_length]
            mask = mask.expand([bsz, n_heads, seq_length, kv_length])
        elif mask.size(3) != padding_kv_length:
            new_mask = torch.empty([bsz, 1, seq_length, padding_kv_length],
                                   dtype=dtype, device=device)
            new_mask[:, :, :, :kv_length] = mask[:, 0:1, :seq_length, :kv_length]
            new_mask[:, :, :, kv_length:] = torch.finfo(dtype).min
            new_mask = new_mask.expand([bsz, n_heads, seq_length, padding_kv_length])
            mask.set_(new_mask)    # modify `mask` in place
        else:
            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
    return mask


def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                                 mask: torch.Tensor = None,
                                 is_causal: bool = False, scale: float = None) -> torch.Tensor:
    bsz, n_heads, seq_length, head_dim = query.shape
    _, n_kv_heads, kv_length, _ = key.shape

    dtype, device = query.dtype, query.device

    if (
        device.type == "xpu"
        and dtype in [torch.float, torch.half]
        and head_dim in [64, 80, 96, 128]
    ):
        # prepare scale
        scale = 1 / math.sqrt(head_dim) if scale is None else scale

        # prepare mask
        mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)

        # compute
        import xe_addons
        if is_causal:
            if key.dtype == torch.uint8:
                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
            else:
                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
        elif seq_length != kv_length and seq_length <= 32:
            # todo: add scale support
            if key.dtype == torch.uint8:
                attn_output = xe_addons.sdp_fp8(query, key, value, mask)
            else:
                attn_output = xe_addons.sdp(query, key, value, mask)
        else:
            if key.dtype == torch.uint8:
                attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
            else:
                attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)

        return attn_output
    else:
        mask = mask[..., :seq_length, :kv_length] if mask is not None else None
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, mask, is_causal=is_causal, scale=scale
        )
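
Editor's note (not part of the commit): prepare_mask pads the key/value length up to the next multiple of max_kvs = 128 so the XPU kernels see a fixed-width mask, and fills every padded column with the dtype minimum so those positions receive zero attention weight. A minimal CPU-runnable sketch of the round-up arithmetic and the causal branch, assuming nothing beyond stock PyTorch:

import torch

# round-up used by prepare_mask: ceil(kv_length / 128) * 128
max_kvs = 128
for kv_length in (1, 128, 129, 300):
    print(kv_length, "->", (kv_length + max_kvs - 1) // max_kvs * max_kvs)
# 1 -> 128, 128 -> 128, 129 -> 256, 300 -> 384

# causal branch: dtype minimum strictly above the diagonal, zeros elsewhere
dtype = torch.float
seq_length, padding_kv_length = 4, 128
mask = torch.full([1, 1, seq_length, padding_kv_length], torch.finfo(dtype).min, dtype=dtype)
mask.triu_(1)
assert (mask[0, 0].diagonal() == 0).all()                    # current and past positions stay unmasked
assert (mask[0, 0, 0, 1:] == torch.finfo(dtype).min).all()   # future and padded columns are masked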
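
Editor's note (not part of the commit): on anything other than an XPU device, or for unsupported dtypes and head sizes, the new scaled_dot_product_attention trims the mask back to [..., :seq_length, :kv_length] and defers to the stock PyTorch op. A hedged usage sketch of that fallback path (shapes chosen arbitrarily; the scale= argument requires PyTorch 2.1 or newer):

import math
import torch

bsz, n_heads, seq_length, head_dim = 2, 8, 16, 64
query = torch.randn(bsz, n_heads, seq_length, head_dim)
key = torch.randn(bsz, n_heads, seq_length, head_dim)
value = torch.randn(bsz, n_heads, seq_length, head_dim)

# CPU tensors: device.type != "xpu", so the function above would take this branch
out = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, None, is_causal=True, scale=1 / math.sqrt(head_dim)
)
print(out.shape)  # torch.Size([2, 8, 16, 64])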