LLM: support speecht5_tts (#10077)

* support speecht5_tts

* fix
This commit is contained in:
Ruonan Wang 2024-02-04 13:26:42 +08:00 committed by GitHub
parent 738275761d
commit 8e33cb0f38
2 changed files with 8 additions and 1 deletions

View file

@ -238,6 +238,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
.to(device)
elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
if in_features % 64 != 0:
# our kernel currently requires in_features to be a multiple of 64
continue
new_linear = LowBitLinear(
in_features,
out_features,

View file

@ -478,7 +478,11 @@ class LowBitLinear(nn.Linear):
if x_2d.is_contiguous() is False:
x_2d = x_2d.contiguous()
input_seq_size = x_shape[1]
if len(x_shape) == 3:
input_seq_size = x_shape[1]
elif len(x_shape) < 3:
input_seq_size = 1
if is_training:
# training path
if x_2d.requires_grad: