remove the code that converts input to fp16 before calling the batch forward kernel (#11489)
parent 1638573f56
commit 5e967205ac
1 changed file with 15 additions and 16 deletions
@@ -743,25 +743,24 @@ class LowBitLinear(nn.Linear):
             do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
             if do_empty_cache:
                 torch.xpu.empty_cache()
-            if self.conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32 and \
-                    not use_xmx(x_2d, self.weight.qtype):
+
+            if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
+                import xe_batch
+                result = xe_batch.batch_forward(x_2d, self.weight.data, self.weight.qtype)
+            elif (
+                self.conver_to_half
+                and x_2d.shape[0] > 1
+                and x_2d.dtype == torch.float32
+                and not use_xmx(x_2d, self.weight.qtype)
+            ):
                 x_2d = x_2d.half()
-                if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
-                    import xe_batch
-                    result = xe_batch.batch_forward(x_2d, self.weight.data,
-                                                    self.weight.qtype)
-                else:
-                    result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype,
-                                                   input_seq_size)
+                result = xe_linear.forward_new(x_2d, self.weight.data,
+                                               self.weight.qtype, input_seq_size)
                 result = result.to(x.dtype)
             else:
-                if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
-                    import xe_batch
-                    result = xe_batch.batch_forward(x_2d, self.weight.data,
-                                                    self.weight.qtype)
-                else:
-                    result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype,
-                                                   input_seq_size)
+                result = xe_linear.forward_new(x_2d, self.weight.data,
+                                               self.weight.qtype, input_seq_size)
+
             if do_empty_cache:
                 torch.xpu.empty_cache()
             result = result.view(new_shape)
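For readers skimming the diff: the change moves the use_batch_forward check ahead of the fp16 conversion, so xe_batch.batch_forward now receives the input in its original dtype, and the fp32-to-fp16 conversion happens only on the generic xe_linear.forward_new path. Below is a minimal, self-contained Python sketch of the resulting dispatch order, not the actual ipex-llm code: the stub functions and the batch-size threshold are made-up stand-ins for use_batch_forward, xe_batch.batch_forward, and xe_linear.forward_new, which are Intel XPU extension kernels that cannot run outside ipex-llm.

import torch

def use_batch_forward_stub(x: torch.Tensor) -> bool:
    # Stand-in for use_batch_forward(x_2d, qtype, out_len); the real check
    # depends on the quantization type and output width, not just batch size.
    return x.shape[0] <= 8

def batch_forward_stub(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Stand-in for xe_batch.batch_forward: called on the input's own dtype,
    # which is the point of this commit (no fp16 conversion beforehand).
    return x @ w.to(x.dtype).t()

def generic_forward_stub(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Stand-in for xe_linear.forward_new; computed in fp32 internally so the
    # sketch also runs on CPUs without fp16 matmul support.
    return (x.float() @ w.float().t()).to(x.dtype)

def forward(x_2d: torch.Tensor, weight: torch.Tensor,
            convert_to_half: bool) -> torch.Tensor:
    # New order: try the batched kernel first, on the unconverted input.
    if use_batch_forward_stub(x_2d):
        return batch_forward_stub(x_2d, weight)
    # Only the generic path converts fp32 inputs to fp16 and back.
    if convert_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
        return generic_forward_stub(x_2d.half(), weight).to(torch.float32)
    return generic_forward_stub(x_2d, weight)

x = torch.randn(4, 16)   # batch of 4 -> takes the batched-kernel path
w = torch.randn(32, 16)
print(forward(x, w, convert_to_half=True).shape)  # torch.Size([4, 32])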