diff --git a/demo/gradio_demo.py b/demo/gradio_demo.py index 0030045..de3410c 100644 --- a/demo/gradio_demo.py +++ b/demo/gradio_demo.py @@ -54,12 +54,23 @@ class VibeVoiceDemo: ) # Load model - self.model = VibeVoiceForConditionalGenerationInference.from_pretrained( - self.model_path, - torch_dtype=torch.bfloat16, - device_map='cuda', - attn_implementation="flash_attention_2", - ) + try: + self.model = VibeVoiceForConditionalGenerationInference.from_pretrained( + self.model_path, + torch_dtype=torch.bfloat16, + device_map='cuda', + attn_implementation='flash_attention_2' # flash_attention_2 is recommended + ) + except Exception as e: + print(f"[ERROR] : {type(e).__name__}: {e}") + print(traceback.format_exc()) + print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.") + self.model = VibeVoiceForConditionalGenerationInference.from_pretrained( + self.model_path, + torch_dtype=torch.bfloat16, + device_map='cuda', + attn_implementation='sdpa' + ) self.model.eval() # Use SDE solver by default