From 5e967205ac87e4e54682fb399a2e490cddd85c6d Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Tue, 2 Jul 2024 16:23:53 +0800 Subject: [PATCH] remove the code converts input to fp16 before calling batch forward kernel (#11489) --- .../ipex_llm/transformers/low_bit_linear.py | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py index 1d632d6f..7bc0f7ce 100644 --- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py +++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py @@ -743,25 +743,24 @@ class LowBitLinear(nn.Linear): do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024 if do_empty_cache: torch.xpu.empty_cache() - if self.conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32 and \ - not use_xmx(x_2d, self.weight.qtype): + + if use_batch_forward(x_2d, self.weight.qtype, self.out_len): + import xe_batch + result = xe_batch.batch_forward(x_2d, self.weight.data, self.weight.qtype) + elif ( + self.conver_to_half + and x_2d.shape[0] > 1 + and x_2d.dtype == torch.float32 + and not use_xmx(x_2d, self.weight.qtype) + ): x_2d = x_2d.half() - if use_batch_forward(x_2d, self.weight.qtype, self.out_len): - import xe_batch - result = xe_batch.batch_forward(x_2d, self.weight.data, - self.weight.qtype) - else: - result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype, - input_seq_size) + result = xe_linear.forward_new(x_2d, self.weight.data, + self.weight.qtype, input_seq_size) result = result.to(x.dtype) else: - if use_batch_forward(x_2d, self.weight.qtype, self.out_len): - import xe_batch - result = xe_batch.batch_forward(x_2d, self.weight.data, - self.weight.qtype) - else: - result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype, - input_seq_size) + result = xe_linear.forward_new(x_2d, self.weight.data, + self.weight.qtype, input_seq_size) + if do_empty_cache: torch.xpu.empty_cache() result = result.view(new_shape)