padding mask on torch side (#12577)
This commit is contained in:
parent
47e90a362f
commit
e0921f80c1
1 changed files with 77 additions and 0 deletions
|
|
@ -184,3 +184,80 @@ def layer_norm_forward(self, hidden_states: torch.Tensor):
|
|||
hidden_states, self.normalized_shape,
|
||||
self.weight, self.bias, self.eps
|
||||
)
|
||||
|
||||
|
||||
def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device):
|
||||
max_kvs = 128
|
||||
padding_kv_length = (kv_length + max_kvs - 1) // max_kvs * max_kvs
|
||||
if mask is None:
|
||||
if is_causal:
|
||||
mask = torch.full([1, 1, seq_length, padding_kv_length], torch.finfo(dtype).min,
|
||||
dtype=dtype, device=device)
|
||||
mask.triu_(1)
|
||||
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
|
||||
elif seq_length != kv_length and seq_length <= 32:
|
||||
mask = None
|
||||
else:
|
||||
mask = torch.zeros([1, 1, 1, padding_kv_length], torch.finfo(dtype).min,
|
||||
dtype=dtype, device=device)
|
||||
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
|
||||
else:
|
||||
if seq_length != kv_length and seq_length <= 32:
|
||||
mask = mask[..., :seq_length, :kv_length]
|
||||
mask = mask.expand([bsz, n_heads, seq_length, kv_length])
|
||||
elif mask.size(3) != padding_kv_length:
|
||||
new_mask = torch.empty([bsz, 1, seq_length, padding_kv_length],
|
||||
dtype=dtype, device=device)
|
||||
new_mask[:, :, :, :kv_length] = mask[:, 0:1, :seq_length, :kv_length]
|
||||
new_mask[:, :, :, kv_length:] = torch.finfo(dtype).min
|
||||
new_mask = new_mask.expand([bsz, n_heads, seq_length, padding_kv_length])
|
||||
mask.set_(new_mask) # modify `mask` inplaced
|
||||
else:
|
||||
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
|
||||
return mask
|
||||
|
||||
|
||||
def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
is_causal: bool = False, scale: float = None) -> torch.Tensor:
|
||||
bsz, n_heads, seq_length, head_dim = query.shape
|
||||
_, n_kv_heads, kv_length, _ = key.shape
|
||||
|
||||
dtype, device = query.dtype, query.device
|
||||
|
||||
if (
|
||||
device.type == "xpu"
|
||||
and dtype in [torch.float, torch.half]
|
||||
and head_dim in [64, 80, 96, 128]
|
||||
):
|
||||
# prepare scale
|
||||
scale = 1 / math.sqrt(head_dim) if scale is None else scale
|
||||
|
||||
# prepare mask
|
||||
mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)
|
||||
|
||||
# compute
|
||||
import xe_addons
|
||||
if is_causal:
|
||||
if key.dtype == torch.uint8:
|
||||
attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
|
||||
else:
|
||||
attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
|
||||
elif seq_length != kv_length and seq_length <= 32:
|
||||
# todo: add scale support
|
||||
if key.dtype == torch.uint8:
|
||||
attn_output = xe_addons.sdp_fp8(query, key, value, mask)
|
||||
else:
|
||||
attn_output = xe_addons.sdp(query, key, value, mask)
|
||||
else:
|
||||
if key.dtype == torch.uint8:
|
||||
attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
|
||||
else:
|
||||
attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
|
||||
|
||||
return attn_output
|
||||
else:
|
||||
mask = mask[..., :seq_length, :kv_length] if mask is not None else None
|
||||
return torch.nn.functional.scaled_dot_product_attention(
|
||||
query, key, value, mask, is_causal=is_causal, scale=scale
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue