add args to use_eager
parent f3df4ae1b7
commit 056bb5b0fa
2 changed files with 9 additions and 1 deletion
@@ -175,6 +175,11 @@ def parse_args():
         default=1.3,
         help="CFG (Classifier-Free Guidance) scale for generation (default: 1.3)",
     )
+    parser.add_argument(
+        "--use_eager",
+        action="store_true",
+        help="Use eager attention mode instead of flash_attention_2",
+    )
 
     return parser.parse_args()
 
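Note (not part of the diff): the new --use_eager flag uses argparse's store_true action, so it defaults to False and becomes True only when passed on the command line. A minimal standalone sketch of just this flag, outside the demo's full parse_args():

    import argparse

    # Only the new flag, copied from the hunk above; the demo's other
    # arguments are omitted for brevity.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_eager",
        action="store_true",
        help="Use eager attention mode instead of flash_attention_2",
    )

    print(parser.parse_args([]).use_eager)               # False -> flash_attention_2
    print(parser.parse_args(["--use_eager"]).use_eager)  # True  -> eager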
@@ -244,11 +249,12 @@ def main():
     processor = VibeVoiceProcessor.from_pretrained(args.model_path)
 
     # Load model
+    attn_implementation = "flash_attention_2" if not args.use_eager else "eager"
     model = VibeVoiceForConditionalGenerationInference.from_pretrained(
         args.model_path,
         torch_dtype=torch.bfloat16,
         device_map='cuda',
-        attn_implementation="flash_attention_2" # we only test flash_attention_2
+        attn_implementation=attn_implementation # flash_attention_2 is recommended, eager may lead to lower audio quality
     )
 
     model.eval()
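Note (not part of the diff): flash_attention_2 in transformers requires the flash-attn package, so the new flag is mainly useful on machines where that package cannot be installed. A small sketch of one way a caller might decide when to request eager mode; this extends the expression added in main() and is an illustration, not the commit's behavior:

    import importlib.util

    def pick_attn_implementation(use_eager: bool) -> str:
        # Eager when explicitly requested, or when the flash_attn module is absent;
        # per the comment in the diff, eager may lead to lower audio quality.
        if use_eager or importlib.util.find_spec("flash_attn") is None:
            return "eager"
        return "flash_attention_2"

    print(pick_attn_implementation(use_eager=False))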
demo/text_examples/2p_short.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
+Speaker 1: I heard there’s big news in TTS lately?
+Speaker 2: Yes! Microsoft Research just open-sourced VibeVoice. The model can generate speech up to 90 minutes long, with smooth delivery and rich emotion — it’s absolutely amazing.