feat: implement --device flag
This commit is contained in:
parent
b212a108df
commit
f06dd4e049
2 changed files with 13 additions and 2 deletions
|
@ -3,3 +3,4 @@ soundfile
|
||||||
python-vlc
|
python-vlc
|
||||||
tqdm
|
tqdm
|
||||||
argparse
|
argparse
|
||||||
|
torch
|
||||||
|
|
14
tts.py
14
tts.py
|
@ -1,7 +1,8 @@
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
from time import sleep
|
from time import sleep, time
|
||||||
|
|
||||||
|
import torch
|
||||||
import argparse
|
import argparse
|
||||||
from kokoro import KPipeline
|
from kokoro import KPipeline
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
@ -32,11 +33,17 @@ def parse_args():
|
||||||
default="demo/tongue-twister.txt",
|
default="demo/tongue-twister.txt",
|
||||||
help="Voice to use (pro, hot, asmr, brit)",
|
help="Voice to use (pro, hot, asmr, brit)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else ("xpu" if torch.xpu.is_available() else "cpu"))),
|
||||||
|
help="Device for inference: cuda | mps | cpu",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args=parse_args()
|
args=parse_args()
|
||||||
pipeline = KPipeline(lang_code='a', device='xpu', repo_id='hexgrad/Kokoro-82M')
|
pipeline = KPipeline(lang_code='a', device=args.device, repo_id='hexgrad/Kokoro-82M')
|
||||||
voice=voices[args.voice]
|
voice=voices[args.voice]
|
||||||
if voice is None:
|
if voice is None:
|
||||||
if args.voice is None:
|
if args.voice is None:
|
||||||
|
@ -57,6 +64,7 @@ def main():
|
||||||
output_files = []
|
output_files = []
|
||||||
length = 0
|
length = 0
|
||||||
|
|
||||||
|
start_time = time()
|
||||||
for i, (gs, ps, audio) in enumerate(generator):
|
for i, (gs, ps, audio) in enumerate(generator):
|
||||||
output_file_name=f'outputs/{name}-{i}.wav'
|
output_file_name=f'outputs/{name}-{i}.wav'
|
||||||
os.makedirs(os.path.dirname(output_file_name), exist_ok=True)
|
os.makedirs(os.path.dirname(output_file_name), exist_ok=True)
|
||||||
|
@ -64,6 +72,8 @@ def main():
|
||||||
sf.write(output_file_name, audio, 24000)
|
sf.write(output_file_name, audio, 24000)
|
||||||
print(u'\u2713', output_file_name)
|
print(u'\u2713', output_file_name)
|
||||||
length = length + 1
|
length = length + 1
|
||||||
|
generation_time = time() - start_time
|
||||||
|
print(f"Generation time: {generation_time:.2f} seconds")
|
||||||
|
|
||||||
for i, output in enumerate(output_files):
|
for i, output in enumerate(output_files):
|
||||||
full_path = os.path.abspath(output)
|
full_path = os.path.abspath(output)
|
||||||
|
|
Loading…
Reference in a new issue