remove the code that converts input to fp16 before calling the batch forward kernel (#11489)
This commit is contained in:
parent
1638573f56
commit
5e967205ac
1 changed file with 15 additions and 16 deletions
|
|
@ -743,25 +743,24 @@ class LowBitLinear(nn.Linear):
|
||||||
do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
|
do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
|
||||||
if do_empty_cache:
|
if do_empty_cache:
|
||||||
torch.xpu.empty_cache()
|
torch.xpu.empty_cache()
|
||||||
if self.conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32 and \
|
|
||||||
not use_xmx(x_2d, self.weight.qtype):
|
if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
|
||||||
|
import xe_batch
|
||||||
|
result = xe_batch.batch_forward(x_2d, self.weight.data, self.weight.qtype)
|
||||||
|
elif (
|
||||||
|
self.conver_to_half
|
||||||
|
and x_2d.shape[0] > 1
|
||||||
|
and x_2d.dtype == torch.float32
|
||||||
|
and not use_xmx(x_2d, self.weight.qtype)
|
||||||
|
):
|
||||||
x_2d = x_2d.half()
|
x_2d = x_2d.half()
|
||||||
if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
|
result = xe_linear.forward_new(x_2d, self.weight.data,
|
||||||
import xe_batch
|
self.weight.qtype, input_seq_size)
|
||||||
result = xe_batch.batch_forward(x_2d, self.weight.data,
|
|
||||||
self.weight.qtype)
|
|
||||||
else:
|
|
||||||
result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype,
|
|
||||||
input_seq_size)
|
|
||||||
result = result.to(x.dtype)
|
result = result.to(x.dtype)
|
||||||
else:
|
else:
|
||||||
if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
|
result = xe_linear.forward_new(x_2d, self.weight.data,
|
||||||
import xe_batch
|
self.weight.qtype, input_seq_size)
|
||||||
result = xe_batch.batch_forward(x_2d, self.weight.data,
|
|
||||||
self.weight.qtype)
|
|
||||||
else:
|
|
||||||
result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype,
|
|
||||||
input_seq_size)
|
|
||||||
if do_empty_cache:
|
if do_empty_cache:
|
||||||
torch.xpu.empty_cache()
|
torch.xpu.empty_cache()
|
||||||
result = result.view(new_shape)
|
result = result.view(new_shape)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue