feat: pass string as argument

- disable torch warnings
This commit is contained in:
Ayo Ayco 2025-09-04 09:07:50 +02:00
parent aaae32b179
commit 6dd477022f
2 changed files with 35 additions and 15 deletions

View file

@ -69,10 +69,16 @@ Running the program without arguments will use the demo text `tongue-twister.txt
$ python tts.py # will use default arguments $ python tts.py # will use default arguments
``` ```
To run the program with an input file, use flag `--input`. You can pass a string as the first argument:
```bash ```bash
$ python tts.py --input demo/tongue-twister.txt $ python tts.py "Hello world!" # will be read by the default voice
```
To run the program with an input file, use the `--input_file` flag:
```bash
$ python tts.py --input_file demo/tongue-twister.txt
``` ```
### Voices ### Voices

40
tts.py
View file

@ -1,6 +1,7 @@
import sys import sys
import os import os
from time import sleep, time from time import sleep, time
import warnings
import torch import torch
import argparse import argparse
@ -9,6 +10,8 @@ import soundfile as sf
import vlc import vlc
from tqdm import tqdm from tqdm import tqdm
# Disable all warnings
warnings.filterwarnings("ignore")
# See voices: https://huggingface.co/hexgrad/Kokoro-82M/blob/main/VOICES.md # See voices: https://huggingface.co/hexgrad/Kokoro-82M/blob/main/VOICES.md
voices = { voices = {
@ -18,23 +21,32 @@ voices = {
'brit': 'bf_emma' 'brit': 'bf_emma'
} }
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Simple TTS") parser = argparse.ArgumentParser(description="Simple TTS")
parser.add_argument(
"input_text",
type=str,
nargs='?',
default="",
help="Text to read",
)
parser.add_argument( parser.add_argument(
"--voice", "--voice",
required=False,
type=str, type=str,
default="pro", default="pro",
help="Voice to use (pro, hot, asmr, brit)", help="Voice to use (pro, hot, asmr, brit)",
) )
parser.add_argument( parser.add_argument(
"--input", "--input_file",
required=False,
type=str, type=str,
default="demo/tongue-twister.txt", default="demo/tongue-twister.txt",
help="Voice to use (pro, hot, asmr, brit)", help="Path to the input text file",
) )
parser.add_argument( parser.add_argument(
"--device", "--device",
required=False,
type=str, type=str,
default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else ("xpu" if torch.xpu.is_available() else "cpu"))), default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else ("xpu" if torch.xpu.is_available() else "cpu"))),
help="Device for inference: cuda | mps | cpu", help="Device for inference: cuda | mps | cpu",
@ -50,28 +62,30 @@ def main():
voice=voices['pro'] if args.voice is None else args.voice voice=voices['pro'] if args.voice is None else args.voice
# filename argument # filename argument
file_path = args.input if args.input_text == "":
directory, file_name = os.path.split(file_path) file_path = args.input_file
directory, file_name = os.path.split(file_path)
name = '.'.join(file_name.split('.')[:-1])
file = open(file_path, "r")
text = file.read()
else:
name = "chat"
text = args.input_text
name = '.'.join(file_name.split('.')[:-1])
file = open(file_path, "r")
text = file.read()
generator = pipeline(text, voice=voice) generator = pipeline(text, voice=voice)
output_files = [] output_files = []
length = 0 length = 0
start_time = time() start_time = time()
print("Generating...")
for i, (gs, ps, audio) in enumerate(generator): for i, (gs, ps, audio) in enumerate(generator):
output_file_name=f'outputs/{name}-{voice}-{i}.wav' output_file_name=f'outputs/{name}-{voice}-{i}.wav'
os.makedirs(os.path.dirname(output_file_name), exist_ok=True) os.makedirs(os.path.dirname(output_file_name), exist_ok=True)
output_files.append(output_file_name) output_files.append(output_file_name)
sf.write(output_file_name, audio, 24000) sf.write(output_file_name, audio, 24000)
print(u'\u2713', output_file_name)
length = length + 1 length = length + 1
generation_time = time() - start_time generation_time = time() - start_time
print(f"Generation time: {generation_time:.2f} seconds") print(f"Done in {generation_time:.2f} seconds")
for i, output in enumerate(output_files): for i, output in enumerate(output_files):
full_path = os.path.abspath(output) full_path = os.path.abspath(output)
@ -84,4 +98,4 @@ def main():
sleep(duration / 100) sleep(duration / 100)
if __name__ == "__main__": if __name__ == "__main__":
main() main()