From 8e33cb0f389100900be7b857467e20ff1706ca22 Mon Sep 17 00:00:00 2001 From: Ruonan Wang Date: Sun, 4 Feb 2024 13:26:42 +0800 Subject: [PATCH] LLM: support speecht5_tts (#10077) * support speecht5_tts * fix --- python/llm/src/bigdl/llm/transformers/convert.py | 3 +++ python/llm/src/bigdl/llm/transformers/low_bit_linear.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index de9127ed..38910d3d 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -238,6 +238,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ .to(device) elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]: + if in_features % 64 != 0: + # now our kernel requires in_features is a multiple of 64 + continue new_linear = LowBitLinear( in_features, out_features, diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index ffeb6cba..08dbab8f 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -478,7 +478,11 @@ class LowBitLinear(nn.Linear): if x_2d.is_contiguous() is False: x_2d = x_2d.contiguous() - input_seq_size = x_shape[1] + if len(x_shape) == 3: + input_seq_size = x_shape[1] + elif len(x_shape) < 3: + input_seq_size = 1 + if is_training: # training path if x_2d.requires_grad: