From 8ed05ba7def528d79329ea2031191f559d827d19 Mon Sep 17 00:00:00 2001 From: YaoyaoChang Date: Wed, 27 Aug 2025 18:20:42 -0700 Subject: [PATCH] update to sdpa --- demo/inference_from_file.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/demo/inference_from_file.py b/demo/inference_from_file.py index 21f3407..b26cfa5 100644 --- a/demo/inference_from_file.py +++ b/demo/inference_from_file.py @@ -176,9 +176,9 @@ def parse_args(): help="CFG (Classifier-Free Guidance) scale for generation (default: 1.3)", ) parser.add_argument( - "--use_eager", + "--use_sdpa", action="store_true", - help="Use eager attention mode instead of flash_attention_2", + help="Use sdpa attention mode instead of flash_attention_2", ) return parser.parse_args() @@ -249,12 +249,12 @@ def main(): processor = VibeVoiceProcessor.from_pretrained(args.model_path) # Load model - attn_implementation = "flash_attention_2" if not args.use_eager else "eager" + attn_implementation = "flash_attention_2" if not args.use_sdpa else "sdpa" model = VibeVoiceForConditionalGenerationInference.from_pretrained( args.model_path, torch_dtype=torch.bfloat16, device_map='cuda', - attn_implementation=attn_implementation # flash_attention_2 is recommended, eager may lead to lower audio quality + attn_implementation=attn_implementation # flash_attention_2 is recommended ) model.eval()