Support new fp8 e4m3 (#11158)
This commit is contained in:
		
							parent
							
								
									8e25de1126
								
							
						
					
					
						commit
						e29e2f1c78
					
				
					 1 changed file with 2 additions and 1 deletion
				
			
		| 
						 | 
					@ -329,13 +329,14 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
 | 
				
			||||||
        and output_len % 32 == 0
 | 
					        and output_len % 32 == 0
 | 
				
			||||||
        and device in ["arc", "flex", "pvc", "mtl"]
 | 
					        and device in ["arc", "flex", "pvc", "mtl"]
 | 
				
			||||||
        and qtype in [SYM_INT4, ASYM_INT4, SYM_INT8, FP4,
 | 
					        and qtype in [SYM_INT4, ASYM_INT4, SYM_INT8, FP4,
 | 
				
			||||||
                      FP8E5, FP6]
 | 
					                      FP8E5, FP6, FP8E4]
 | 
				
			||||||
        and batch_size <= 64
 | 
					        and batch_size <= 64
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    if hard_condition:
 | 
					    if hard_condition:
 | 
				
			||||||
        return (
 | 
					        return (
 | 
				
			||||||
            batch_size > 1
 | 
					            batch_size > 1
 | 
				
			||||||
            or (device in ["arc", "flex"] and qtype in [SYM_INT8, FP4])
 | 
					            or (device in ["arc", "flex"] and qtype in [SYM_INT8, FP4])
 | 
				
			||||||
 | 
					            or (device in ["arc", "flex", "mtl"] and qtype in [FP8E4])
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    return False
 | 
					    return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue