remove the code that converts input to fp16 before calling the batch forward kernel (#11489)

Yishuo Wang 2024-07-02 16:23:53 +08:00 committed by GitHub
parent 1638573f56
commit 5e967205ac


@@ -743,25 +743,24 @@ class LowBitLinear(nn.Linear):
             do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
             if do_empty_cache:
                 torch.xpu.empty_cache()
-            if self.conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32 and \
-                    not use_xmx(x_2d, self.weight.qtype):
+            if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
+                import xe_batch
+                result = xe_batch.batch_forward(x_2d, self.weight.data, self.weight.qtype)
+            elif (
+                self.conver_to_half
+                and x_2d.shape[0] > 1
+                and x_2d.dtype == torch.float32
+                and not use_xmx(x_2d, self.weight.qtype)
+            ):
                 x_2d = x_2d.half()
-                if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
-                    import xe_batch
-                    result = xe_batch.batch_forward(x_2d, self.weight.data,
-                                                    self.weight.qtype)
-                else:
-                    result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype,
-                                                   input_seq_size)
+                result = xe_linear.forward_new(x_2d, self.weight.data,
+                                               self.weight.qtype, input_seq_size)
                 result = result.to(x.dtype)
             else:
-                if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
-                    import xe_batch
-                    result = xe_batch.batch_forward(x_2d, self.weight.data,
-                                                    self.weight.qtype)
-                else:
-                    result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype,
-                                                   input_seq_size)
+                result = xe_linear.forward_new(x_2d, self.weight.data,
+                                               self.weight.qtype, input_seq_size)
             if do_empty_cache:
                 torch.xpu.empty_cache()
             result = result.view(new_shape)
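
For context, below is a minimal sketch of the dispatch order after this change, assuming the XPU extension modules (xe_batch, xe_linear) and the helpers use_batch_forward / use_xmx behave as shown in the diff; the standalone wrapper name dispatch_low_bit_matmul is hypothetical and used only for illustration, not the repository's actual forward() code.

import torch

def dispatch_low_bit_matmul(self, x, x_2d, input_seq_size):
    # xe_batch / xe_linear are the XPU kernels from the diff above
    # (assumed installed); use_batch_forward and use_xmx are assumed in scope.
    import xe_batch
    import xe_linear

    # 1. The batch-forward kernel is now checked first and receives x_2d in
    #    its original dtype; the old fp16 pre-conversion no longer applies
    #    on this path.
    if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
        return xe_batch.batch_forward(x_2d, self.weight.data, self.weight.qtype)

    # 2. fp32 inputs may still be downcast to fp16 before the generic
    #    low-bit GEMM when conver_to_half is set and XMX is not used for
    #    this qtype.
    if (self.conver_to_half and x_2d.shape[0] > 1
            and x_2d.dtype == torch.float32
            and not use_xmx(x_2d, self.weight.qtype)):
        x_2d = x_2d.half()
        result = xe_linear.forward_new(x_2d, self.weight.data,
                                       self.weight.qtype, input_seq_size)
        return result.to(x.dtype)  # cast back to the caller's dtype

    # 3. Default path: generic low-bit GEMM on the original dtype.
    return xe_linear.forward_new(x_2d, self.weight.data,
                                 self.weight.qtype, input_seq_size)

The practical effect is that inputs taking the batch-forward path are no longer converted to fp16 first: the fp16 downcast (and the cast back to the caller's dtype) now only applies to the xe_linear.forward_new fallback.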