feat: implement --device flag
This commit is contained in:
parent
b212a108df
commit
f06dd4e049
2 changed files with 13 additions and 2 deletions
|
@ -3,3 +3,4 @@ soundfile
|
|||
python-vlc
|
||||
tqdm
|
||||
argparse
|
||||
torch
|
||||
|
|
14
tts.py
14
tts.py
|
@ -1,7 +1,8 @@
|
|||
import sys
|
||||
import os
|
||||
from time import sleep
|
||||
from time import sleep, time
|
||||
|
||||
import torch
|
||||
import argparse
|
||||
from kokoro import KPipeline
|
||||
import soundfile as sf
|
||||
|
@ -32,11 +33,17 @@ def parse_args():
|
|||
default="demo/tongue-twister.txt",
|
||||
help="Voice to use (pro, hot, asmr, brit)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else ("xpu" if torch.xpu.is_available() else "cpu"))),
|
||||
help="Device for inference: cuda | mps | cpu",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
def main():
|
||||
args=parse_args()
|
||||
pipeline = KPipeline(lang_code='a', device='xpu', repo_id='hexgrad/Kokoro-82M')
|
||||
pipeline = KPipeline(lang_code='a', device=args.device, repo_id='hexgrad/Kokoro-82M')
|
||||
voice=voices[args.voice]
|
||||
if voice is None:
|
||||
if args.voice is None:
|
||||
|
@ -57,6 +64,7 @@ def main():
|
|||
output_files = []
|
||||
length = 0
|
||||
|
||||
start_time = time()
|
||||
for i, (gs, ps, audio) in enumerate(generator):
|
||||
output_file_name=f'outputs/{name}-{i}.wav'
|
||||
os.makedirs(os.path.dirname(output_file_name), exist_ok=True)
|
||||
|
@ -64,6 +72,8 @@ def main():
|
|||
sf.write(output_file_name, audio, 24000)
|
||||
print(u'\u2713', output_file_name)
|
||||
length = length + 1
|
||||
generation_time = time() - start_time
|
||||
print(f"Generation time: {generation_time:.2f} seconds")
|
||||
|
||||
for i, output in enumerate(output_files):
|
||||
full_path = os.path.abspath(output)
|
||||
|
|
Loading…
Reference in a new issue