From 174e53fc04ba5eec33350f3958604555a5a64f88 Mon Sep 17 00:00:00 2001 From: YaoyaoChang Date: Sat, 30 Aug 2025 08:22:32 -0700 Subject: [PATCH] gradio support sdpa --- demo/gradio_demo.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/demo/gradio_demo.py b/demo/gradio_demo.py index 0030045..de3410c 100644 --- a/demo/gradio_demo.py +++ b/demo/gradio_demo.py @@ -54,12 +54,23 @@ class VibeVoiceDemo: ) # Load model - self.model = VibeVoiceForConditionalGenerationInference.from_pretrained( - self.model_path, - torch_dtype=torch.bfloat16, - device_map='cuda', - attn_implementation="flash_attention_2", - ) + try: + self.model = VibeVoiceForConditionalGenerationInference.from_pretrained( + self.model_path, + torch_dtype=torch.bfloat16, + device_map='cuda', + attn_implementation='flash_attention_2' # flash_attention_2 is recommended + ) + except Exception as e: + print(f"[ERROR] : {type(e).__name__}: {e}") + print(traceback.format_exc()) + print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.") + self.model = VibeVoiceForConditionalGenerationInference.from_pretrained( + self.model_path, + torch_dtype=torch.bfloat16, + device_map='cuda', + attn_implementation='sdpa' + ) self.model.eval() # Use SDE solver by default