diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index c36445ec..9b941639 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -362,16 +362,19 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
     batch_size = x.shape[0]
     hard_condition = (
         x.dtype in [torch.float, torch.half]
+        and batch_size <= 48
         and (
-            x.shape[1] % 128 == 0 and qtype in [SYM_INT4]
+            (
+                qtype in [SYM_INT4, ASYM_INT4, FP8E5, FP8E4]
+                and x.shape[1] % 128 == 0
+            )
             or (
-                x.shape[1] % 256 == 0
-                and output_len % 32 == 0
+                qtype in [SYM_INT8, FP4, FP6, Q4_K, Q6_K]
                 and device in ["arc", "flex", "pvc", "mtl"]
-                and qtype in [ASYM_INT4, SYM_INT8, FP4, FP8E5, FP6, FP8E4, Q4_K, Q6_K]
+                and x.shape[1] % 256 == 0
+                and output_len % 32 == 0
             )
         )
-        and batch_size <= 48
     )
     if hard_condition:
         return (
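
For reviewers, a minimal standalone sketch of the restructured predicate after this change. The qtype constants are stubbed with placeholder values here (the real integer codes are imported elsewhere in low_bit_linear.py), and `device` is computed earlier in use_batch_forward, outside this hunk, so it appears as an explicit parameter in the sketch:

    import torch

    # Placeholder codes for a self-contained example; names match the diff,
    # values are NOT the real ones used by ipex_llm.
    SYM_INT4, ASYM_INT4, SYM_INT8 = 0, 1, 2
    FP4, FP6, FP8E4, FP8E5 = 3, 4, 5, 6
    Q4_K, Q6_K = 7, 8

    def batch_forward_hard_condition(x: torch.Tensor, qtype: int,
                                     output_len: int, device: str) -> bool:
        """Gate as restructured by the diff: the batch cap is hoisted to the
        top so oversized batches short-circuit before any shape checks."""
        batch_size = x.shape[0]
        return (
            x.dtype in [torch.float, torch.half]
            and batch_size <= 48
            and (
                # 4-bit and FP8 qtypes: only a 128-aligned input width is
                # required, with no device restriction in this predicate.
                (
                    qtype in [SYM_INT4, ASYM_INT4, FP8E5, FP8E4]
                    and x.shape[1] % 128 == 0
                )
                # Remaining qtypes: stricter alignment (256-wide input,
                # 32-aligned output) and only on the listed GPU families.
                or (
                    qtype in [SYM_INT8, FP4, FP6, Q4_K, Q6_K]
                    and device in ["arc", "flex", "pvc", "mtl"]
                    and x.shape[1] % 256 == 0
                    and output_len % 32 == 0
                )
            )
        )

    # Example: a half-precision (1, 4096) input with SYM_INT4 passes the gate.
    print(batch_forward_hard_condition(torch.empty(1, 4096, dtype=torch.half),
                                       SYM_INT4, output_len=4096, device="arc"))

Net effect of the diff: batch size and qtype are now checked before the shape alignment tests, and the two branches group qtypes by kernel requirements (128-aligned width for the 4-bit/FP8 path, 256/32 alignment plus a device check for the rest) instead of splitting the qtype list across the branches.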