padding mask on torch side (#12577)

parent 47e90a362f
commit e0921f80c1

1 changed file with 77 additions and 0 deletions
@@ -184,3 +184,80 @@ def layer_norm_forward(self, hidden_states: torch.Tensor):
             hidden_states, self.normalized_shape,
             self.weight, self.bias, self.eps
         )
+
+
+def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device):
+    max_kvs = 128
+    padding_kv_length = (kv_length + max_kvs - 1) // max_kvs * max_kvs
+    if mask is None:
+        if is_causal:
+            mask = torch.full([1, 1, seq_length, padding_kv_length], torch.finfo(dtype).min,
+                              dtype=dtype, device=device)
+            mask.triu_(1)
+            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
+        elif seq_length != kv_length and seq_length <= 32:
+            mask = None
+        else:
+            mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
+            mask[..., kv_length:] = torch.finfo(dtype).min
+            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
+    else:
+        if seq_length != kv_length and seq_length <= 32:
+            mask = mask[..., :seq_length, :kv_length]
+            mask = mask.expand([bsz, n_heads, seq_length, kv_length])
+        elif mask.size(3) != padding_kv_length:
+            new_mask = torch.empty([bsz, 1, seq_length, padding_kv_length],
+                                   dtype=dtype, device=device)
+            new_mask[:, :, :, :kv_length] = mask[:, 0:1, :seq_length, :kv_length]
+            new_mask[:, :, :, kv_length:] = torch.finfo(dtype).min
+            new_mask = new_mask.expand([bsz, n_heads, seq_length, padding_kv_length])
+            mask.set_(new_mask)     # modify `mask` in place
+        else:
+            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
+    return mask
+
+
+def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
+                                 mask: torch.Tensor = None,
+                                 is_causal: bool = False, scale: float = None) -> torch.Tensor:
+    bsz, n_heads, seq_length, head_dim = query.shape
+    _, n_kv_heads, kv_length, _ = key.shape
+
+    dtype, device = query.dtype, query.device
+
+    if (
+        device.type == "xpu"
+        and dtype in [torch.float, torch.half]
+        and head_dim in [64, 80, 96, 128]
+    ):
+        # prepare scale
+        scale = 1 / math.sqrt(head_dim) if scale is None else scale
+
+        # prepare mask
+        mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)
+
+        # compute
+        import xe_addons
+        if is_causal:
+            if key.dtype == torch.uint8:
+                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
+            else:
+                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
+        elif seq_length != kv_length and seq_length <= 32:
+            # todo: add scale support
+            if key.dtype == torch.uint8:
+                attn_output = xe_addons.sdp_fp8(query, key, value, mask)
+            else:
+                attn_output = xe_addons.sdp(query, key, value, mask)
+        else:
+            if key.dtype == torch.uint8:
+                attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
+            else:
+                attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
+
+        return attn_output
+    else:
+        mask = mask[..., :seq_length, :kv_length] if mask is not None else None
+        return torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, mask, is_causal=is_causal, scale=scale
+        )
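For reference, the `(kv_length + max_kvs - 1) // max_kvs * max_kvs` expression in `prepare_mask` rounds the key/value length up to the next multiple of 128, which appears to be the block size the XPU kernels expect. A minimal standalone sketch of the arithmetic (values chosen for illustration only):

# Round-up-to-multiple arithmetic used by prepare_mask (max_kvs = 128).
max_kvs = 128
for kv_length in (1, 128, 129, 300, 1024):
    padding_kv_length = (kv_length + max_kvs - 1) // max_kvs * max_kvs
    print(kv_length, "->", padding_kv_length)
# 1 -> 128, 128 -> 128, 129 -> 256, 300 -> 384, 1024 -> 1024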
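The causal branch builds an additive mask: a tensor filled with the dtype's most negative value, on which `triu_(1)` zeroes the diagonal and everything below it, so only "future" positions stay masked. A small CPU illustration of that construction (shapes shrunk for readability; this mirrors, not replaces, the code above):

import torch

dtype = torch.float
# Fill with the most negative representable value, then keep only the
# strict upper triangle: row i may not attend to columns j > i.
mask = torch.full([1, 1, 4, 4], torch.finfo(dtype).min, dtype=dtype)
mask.triu_(1)
print(mask[0, 0])
# tensor([[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
#         [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
#         [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
#         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]])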
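The `mask.set_(new_mask)` call is worth a note: `Tensor.set_` makes the existing tensor adopt the source's storage, size, and strides, so any caller still holding a reference to the original mask sees the re-padded version without rebinding a name. A minimal sketch of that behavior:

import torch

mask = torch.zeros(2, 3)
alias = mask                      # e.g. a reference held by the caller
new_mask = torch.ones(2, 5)
mask.set_(new_mask)               # mask now shares new_mask's storage
print(alias.shape)                # torch.Size([2, 5]) -- alias sees it too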
				
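Finally, a hedged usage sketch of the new wrapper, assuming it is importable from the patched module (the import path below is hypothetical). On a non-XPU device, or for an unsupported dtype or head_dim, the function falls through to torch.nn.functional.scaled_dot_product_attention, so this also runs on CPU:

import torch
# Hypothetical import path; the actual module is the patched file above.
# from ipex_llm.transformers.models.common import scaled_dot_product_attention

bsz, n_heads, seq_length, head_dim = 2, 8, 16, 64
q = torch.randn(bsz, n_heads, seq_length, head_dim)
k = torch.randn(bsz, n_heads, seq_length, head_dim)
v = torch.randn(bsz, n_heads, seq_length, head_dim)

out = scaled_dot_product_attention(q, k, v, mask=None, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])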
	