parent 738275761d
commit 8e33cb0f38
2 changed files with 8 additions and 1 deletion
@@ -238,6 +238,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                 .to(device)
         elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
+            if in_features % 64 != 0:
+                # now our kernel requires in_features is a multiple of 64
+                continue
             new_linear = LowBitLinear(
                 in_features,
                 out_features,
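Note (reviewer sketch, not part of the commit): the added `in_features % 64 != 0` guard makes the conversion loop skip any linear layer whose input width is not a multiple of 64, so such layers stay as plain nn.Linear modules. A minimal standalone illustration with made-up layer sizes:

# Reviewer sketch only (assumed layer sizes, not from this commit): which layers
# the added `in_features % 64 != 0` guard would leave unconverted.
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(4096, 11008),  # 4096 % 64 == 0  -> eligible for LowBitLinear conversion
    nn.Linear(1000, 256),    # 1000 % 64 != 0  -> guard hits `continue`, stays nn.Linear
)

for name, module in model.named_children():
    if module.in_features % 64 != 0:
        # mirrors the added check: the low-bit kernel needs in_features % 64 == 0
        print(f"skip layer {name}: in_features={module.in_features}")
        continue
    print(f"would convert layer {name}: in_features={module.in_features}")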
@@ -478,7 +478,11 @@ class LowBitLinear(nn.Linear):
         if x_2d.is_contiguous() is False:
             x_2d = x_2d.contiguous()

-        input_seq_size = x_shape[1]
+        if len(x_shape) == 3:
+            input_seq_size = x_shape[1]
+        elif len(x_shape) < 3:
+            input_seq_size = 1
+
         if is_training:
             # training path
             if x_2d.requires_grad:
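For context, a small standalone sketch (tensor shapes assumed, helper name pick_input_seq_size is hypothetical) of what the replaced `input_seq_size = x_shape[1]` line becomes: 3-D activations keep using the sequence dimension, while 2-D inputs, for which the old line would have picked up the hidden size, now fall back to 1.

# Reviewer sketch only (assumed shapes, not from this commit): how the new
# branch in LowBitLinear.forward resolves input_seq_size for 3-D vs 2-D inputs.
import torch

def pick_input_seq_size(x: torch.Tensor) -> int:
    x_shape = x.shape
    if len(x_shape) == 3:
        # (batch, seq_len, hidden): use the sequence dimension, as before
        return x_shape[1]
    elif len(x_shape) < 3:
        # 2-D (tokens, hidden) or 1-D input: treated as a single step
        return 1

print(pick_input_seq_size(torch.randn(2, 128, 4096)))  # -> 128
print(pick_input_seq_size(torch.randn(16, 4096)))      # -> 1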