update sdp support (#12847)
This commit is contained in:
		
							parent
							
								
									93c10be762
								
							
						
					
					
						commit
						aee2db30f9
					
				
					 1 changed files with 1 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -230,7 +230,7 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
 | 
			
		|||
    if (
 | 
			
		||||
        device.type == "xpu"
 | 
			
		||||
        and dtype in [torch.float, torch.half]
 | 
			
		||||
        and head_dim in [64, 80, 96, 128]
 | 
			
		||||
        and head_dim in [64, 80, 96, 128, 192, 256]
 | 
			
		||||
    ):
 | 
			
		||||
        # prepare scale
 | 
			
		||||
        scale = 1 / math.sqrt(head_dim) if scale is None else scale
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue