parent 738275761d
commit 8e33cb0f38
2 changed files with 8 additions and 1 deletion

@@ -238,6 +238,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                     .to(device)
             elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
+                if in_features % 64 != 0:
+                    # now our kernel requires in_features is a multiple of 64
+                    continue
                 new_linear = LowBitLinear(
                     in_features,
                     out_features,
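
The first hunk makes _replace_with_low_bit_linear skip nn.Linear layers whose in_features is not a multiple of 64, since the low-bit kernel requires 64-aligned input features. The sketch below is not part of the commit; it is a minimal, hypothetical illustration (the helper name find_skipped_linears and the demo model are assumptions) of which layers that guard would leave unconverted.

import torch.nn as nn

def find_skipped_linears(model: nn.Module, alignment: int = 64):
    """Hypothetical helper: list nn.Linear modules the 64-alignment guard would skip."""
    skipped = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and module.in_features % alignment != 0:
            # mirrors the added check: the kernel requires in_features % 64 == 0,
            # so conversion of such a layer is skipped via `continue`
            skipped.append(name)
    return skipped

if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(100, 64), nn.Linear(64, 64))
    print(find_skipped_linears(model))  # ['0'] -- in_features=100 is not a multiple of 64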

@@ -478,7 +478,11 @@ class LowBitLinear(nn.Linear):
         if x_2d.is_contiguous() is False:
             x_2d = x_2d.contiguous()
 
-        input_seq_size = x_shape[1]
+        if len(x_shape) == 3:
+            input_seq_size = x_shape[1]
+        elif len(x_shape) < 3:
+            input_seq_size = 1
+
         if is_training:
             # training path
             if x_2d.requires_grad:
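
The second hunk replaces the unconditional input_seq_size = x_shape[1] in LowBitLinear.forward with shape-aware handling: a 3-D (batch, seq_len, hidden) input reports its sequence length, while 1-D/2-D inputs report 1. The snippet below is not part of the commit; infer_input_seq_size is a hypothetical helper that just restates the added branches for clarity.

import torch

def infer_input_seq_size(x: torch.Tensor) -> int:
    """Hypothetical helper restating the branches added to LowBitLinear.forward."""
    x_shape = x.shape
    if len(x_shape) == 3:
        return x_shape[1]   # (batch, seq_len, hidden) -> seq_len
    elif len(x_shape) < 3:
        return 1            # 1-D or 2-D input: treated as a single step
    # ranks above 3 are not covered by the branches added in this commit
    return x_shape[1]

if __name__ == "__main__":
    print(infer_input_seq_size(torch.zeros(2, 16, 4096)))  # 16
    print(infer_input_seq_size(torch.zeros(8, 4096)))      # 1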