From 3d5fbf20695280fbc24745533b5ca8b9b6e7a00e Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 15 Nov 2024 13:47:05 +0800
Subject: [PATCH] update batch kernel condition (#12408)

---
 .../llm/src/ipex_llm/transformers/low_bit_linear.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index c36445ec..9b941639 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -362,16 +362,19 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
     batch_size = x.shape[0]
     hard_condition = (
         x.dtype in [torch.float, torch.half]
+        and batch_size <= 48
         and (
-            x.shape[1] % 128 == 0 and qtype in [SYM_INT4]
+            (
+                qtype in [SYM_INT4, ASYM_INT4, FP8E5, FP8E4]
+                and x.shape[1] % 128 == 0
+            )
             or (
-                x.shape[1] % 256 == 0
-                and output_len % 32 == 0
+                qtype in [SYM_INT8, FP4, FP6, Q4_K, Q6_K]
                 and device in ["arc", "flex", "pvc", "mtl"]
-                and qtype in [ASYM_INT4, SYM_INT8, FP4, FP8E5, FP6, FP8E4, Q4_K, Q6_K]
+                and x.shape[1] % 256 == 0
+                and output_len % 32 == 0
             )
         )
-        and batch_size <= 48
     )
     if hard_condition:
         return (
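
Reviewer note (not part of the patch): the change hoists the batch cap
(batch_size <= 48) to the front of the condition and regroups the qtypes so
that each branch leads with its qtype check; ASYM_INT4, FP8E5, and FP8E4 move
from the device-restricted 256-aligned branch into the device-agnostic
128-aligned branch. Below is a minimal, self-contained sketch of the
restructured gate for readers who want to try it outside the module. The
qtype constant values and the use_batch_forward_sketch name are assumptions
for illustration; the real constants come from
ipex_llm.transformers.low_bit_linear, and the real function derives the
device name from x.device rather than taking it as a parameter.

import torch

# Hypothetical stand-ins for the module's qtype constants (illustrative
# values only, not the real ones from low_bit_linear).
SYM_INT4, ASYM_INT4, SYM_INT8 = 2, 3, 8
FP4, FP6, FP8E4, FP8E5 = 10, 29, 15, 19
Q4_K, Q6_K = 12, 14


def use_batch_forward_sketch(x: torch.Tensor, qtype: int,
                             output_len: int, device: str) -> bool:
    # Mirrors the patched hard_condition; the device name is passed in
    # directly instead of being looked up from x.device as in the module.
    batch_size = x.shape[0]
    return (
        x.dtype in [torch.float, torch.half]
        and batch_size <= 48
        and (
            # 128-aligned branch: int4 and fp8 qtypes, on any device.
            (
                qtype in [SYM_INT4, ASYM_INT4, FP8E5, FP8E4]
                and x.shape[1] % 128 == 0
            )
            # 256-aligned branch: remaining qtypes, Intel GPU families only,
            # and the output width must also be a multiple of 32.
            or (
                qtype in [SYM_INT8, FP4, FP6, Q4_K, Q6_K]
                and device in ["arc", "flex", "pvc", "mtl"]
                and x.shape[1] % 256 == 0
                and output_len % 32 == 0
            )
        )
    )


# Example: fp16 activations, batch 8, hidden size 4096, SYM_INT4 weights
# on an Arc GPU satisfy the 128-aligned branch.
x = torch.empty(8, 4096, dtype=torch.half)
assert use_batch_forward_sketch(x, SYM_INT4, output_len=4096, device="arc")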