feat: implement --device flag

This commit is contained in:
Ayo Ayco 2025-09-03 22:50:02 +02:00
parent b212a108df
commit f06dd4e049
2 changed files with 13 additions and 2 deletions

View file

@ -3,3 +3,4 @@ soundfile
python-vlc
tqdm
argparse
torch

14
tts.py
View file

@ -1,7 +1,8 @@
import sys
import os
from time import sleep
from time import sleep, time
import torch
import argparse
from kokoro import KPipeline
import soundfile as sf
@ -32,11 +33,17 @@ def parse_args():
default="demo/tongue-twister.txt",
help="Voice to use (pro, hot, asmr, brit)",
)
parser.add_argument(
"--device",
type=str,
default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else ("xpu" if torch.xpu.is_available() else "cpu"))),
help="Device for inference: cuda | mps | cpu",
)
return parser.parse_args()
def main():
args=parse_args()
pipeline = KPipeline(lang_code='a', device='xpu', repo_id='hexgrad/Kokoro-82M')
pipeline = KPipeline(lang_code='a', device=args.device, repo_id='hexgrad/Kokoro-82M')
voice=voices[args.voice]
if voice is None:
if args.voice is None:
@ -57,6 +64,7 @@ def main():
output_files = []
length = 0
start_time = time()
for i, (gs, ps, audio) in enumerate(generator):
output_file_name=f'outputs/{name}-{i}.wav'
os.makedirs(os.path.dirname(output_file_name), exist_ok=True)
@ -64,6 +72,8 @@ def main():
sf.write(output_file_name, audio, 24000)
print(u'\u2713', output_file_name)
length = length + 1
generation_time = time() - start_time
print(f"Generation time: {generation_time:.2f} seconds")
for i, output in enumerate(output_files):
full_path = os.path.abspath(output)