From 145e8b480f8b1c54adddbc7d8a3b808ce663e0b4 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Thu, 21 Nov 2024 10:12:46 +0800 Subject: [PATCH] update batch kernel condition (#12421) --- .../llm/src/ipex_llm/transformers/low_bit_linear.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py index 9b941639..fc461d04 100644 --- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py +++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py @@ -362,14 +362,22 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int): batch_size = x.shape[0] hard_condition = ( x.dtype in [torch.float, torch.half] - and batch_size <= 48 + and x.shape[1] % 128 == 0 and ( ( qtype in [SYM_INT4, ASYM_INT4, FP8E5, FP8E4] - and x.shape[1] % 128 == 0 + and ( + batch_size <= 48 + or ( + batch_size <= 64 + and x.shape[1] % 256 == 0 + and output_len % 64 == 0 + ) + ) ) or ( qtype in [SYM_INT8, FP4, FP6, Q4_K, Q6_K] + and batch_size <= 48 and device in ["arc", "flex", "pvc", "mtl"] and x.shape[1] % 256 == 0 and output_len % 32 == 0