feat: implement --device flag

This commit is contained in:
Ayo Ayco 2025-09-03 22:50:02 +02:00
parent b212a108df
commit f06dd4e049
2 changed files with 13 additions and 2 deletions

View file

@ -3,3 +3,4 @@ soundfile
python-vlc python-vlc
tqdm tqdm
argparse argparse
torch

14
tts.py
View file

@ -1,7 +1,8 @@
import sys import sys
import os import os
from time import sleep from time import sleep, time
import torch
import argparse import argparse
from kokoro import KPipeline from kokoro import KPipeline
import soundfile as sf import soundfile as sf
@ -32,11 +33,17 @@ def parse_args():
default="demo/tongue-twister.txt", default="demo/tongue-twister.txt",
help="Voice to use (pro, hot, asmr, brit)", help="Voice to use (pro, hot, asmr, brit)",
) )
parser.add_argument(
"--device",
type=str,
default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else ("xpu" if torch.xpu.is_available() else "cpu"))),
help="Device for inference: cuda | mps | cpu",
)
return parser.parse_args() return parser.parse_args()
def main(): def main():
args=parse_args() args=parse_args()
pipeline = KPipeline(lang_code='a', device='xpu', repo_id='hexgrad/Kokoro-82M') pipeline = KPipeline(lang_code='a', device=args.device, repo_id='hexgrad/Kokoro-82M')
voice=voices[args.voice] voice=voices[args.voice]
if voice is None: if voice is None:
if args.voice is None: if args.voice is None:
@ -57,6 +64,7 @@ def main():
output_files = [] output_files = []
length = 0 length = 0
start_time = time()
for i, (gs, ps, audio) in enumerate(generator): for i, (gs, ps, audio) in enumerate(generator):
output_file_name=f'outputs/{name}-{i}.wav' output_file_name=f'outputs/{name}-{i}.wav'
os.makedirs(os.path.dirname(output_file_name), exist_ok=True) os.makedirs(os.path.dirname(output_file_name), exist_ok=True)
@ -64,6 +72,8 @@ def main():
sf.write(output_file_name, audio, 24000) sf.write(output_file_name, audio, 24000)
print(u'\u2713', output_file_name) print(u'\u2713', output_file_name)
length = length + 1 length = length + 1
generation_time = time() - start_time
print(f"Generation time: {generation_time:.2f} seconds")
for i, output in enumerate(output_files): for i, output in enumerate(output_files):
full_path = os.path.abspath(output) full_path = os.path.abspath(output)