parent
e3130a06ed
commit
531bef2810
1 changed files with 2 additions and 1 deletions
|
|
@ -654,7 +654,8 @@ class LowBitLinear(nn.Linear):
|
|||
else:
|
||||
w = self.weight.data
|
||||
|
||||
if use_batch_forward(x_2d, self.weight.qtype, self.out_len) and self.conver_to_half:
|
||||
if use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
|
||||
(x_2d.dtype == torch.half or self.conver_to_half):
|
||||
import xe_batch
|
||||
result = xe_batch.batch_forward(x_2d, w, self.qtype)
|
||||
elif not is_training and self.conver_to_half \
|
||||
|
|
|
|||
Loading…
Reference in a new issue