use fp16_sdp when head_dim=96 (#10976)
parent b7f7d05a7e
commit e753125880
1 changed file with 1 addition and 1 deletion
@@ -362,7 +362,7 @@ def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
     elif query_states.dtype != torch.float16:
         # esimd_sdp only has optimization for FP16 now
         return False
-    elif head_dim != 128 and head_dim != 64:
+    elif head_dim not in [64, 96, 128]:
         # esimd_sdp only support head_dim = 128 and 64 now
         return False
     elif q_len == k_len:
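For context, the sketch below restates the gate from this hunk as a standalone check. It is illustrative only: the helper name head_dim_is_supported and the tensor shapes are not from the repository, and the real use_new_esimd_sdp_fp16 applies further conditions (for example the q_len == k_len branch visible above).

import torch

def head_dim_is_supported(query_states, head_dim):
    # Hypothetical helper mirroring the two checks in the hunk above;
    # the real use_new_esimd_sdp_fp16 has additional branches.
    if query_states.dtype != torch.float16:
        # esimd_sdp only has an FP16 optimization
        return False
    if head_dim not in [64, 96, 128]:
        # after this commit, head_dim = 96 is also accepted
        return False
    return True

# head_dim = 96 now passes the gate as long as the query states are FP16.
q = torch.randn(1, 32, 16, 96, dtype=torch.float16)
print(head_dim_is_supported(q, head_dim=96))          # True
print(head_dim_is_supported(q.float(), head_dim=96))  # False: not FP16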