parent
e3130a06ed
commit
531bef2810
1 changed files with 2 additions and 1 deletions
|
|
@ -654,7 +654,8 @@ class LowBitLinear(nn.Linear):
|
||||||
else:
|
else:
|
||||||
w = self.weight.data
|
w = self.weight.data
|
||||||
|
|
||||||
if use_batch_forward(x_2d, self.weight.qtype, self.out_len) and self.conver_to_half:
|
if use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
|
||||||
|
(x_2d.dtype == torch.half or self.conver_to_half):
|
||||||
import xe_batch
|
import xe_batch
|
||||||
result = xe_batch.batch_forward(x_2d, w, self.qtype)
|
result = xe_batch.batch_forward(x_2d, w, self.qtype)
|
||||||
elif not is_training and self.conver_to_half \
|
elif not is_training and self.conver_to_half \
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue