diff --git a/demo/inference_from_file.py b/demo/inference_from_file.py index b26cfa5..a15aaf8 100644 --- a/demo/inference_from_file.py +++ b/demo/inference_from_file.py @@ -1,6 +1,7 @@ import argparse import os import re +import traceback from typing import List, Tuple, Union, Dict, Any import time import torch @@ -175,11 +176,6 @@ def parse_args(): default=1.3, help="CFG (Classifier-Free Guidance) scale for generation (default: 1.3)", ) - parser.add_argument( - "--use_sdpa", - action="store_true", - help="Use sdpa attention mode instead of flash_attention_2", - ) return parser.parse_args() @@ -249,13 +245,23 @@ def main(): processor = VibeVoiceProcessor.from_pretrained(args.model_path) # Load model - attn_implementation = "flash_attention_2" if not args.use_sdpa else "sdpa" - model = VibeVoiceForConditionalGenerationInference.from_pretrained( - args.model_path, - torch_dtype=torch.bfloat16, - device_map='cuda', - attn_implementation=attn_implementation # flash_attention_2 is recommended - ) + try: + model = VibeVoiceForConditionalGenerationInference.from_pretrained( + args.model_path, + torch_dtype=torch.bfloat16, + device_map='cuda', + attn_implementation='flash_attention_2' # flash_attention_2 is recommended + ) + except Exception as e: + print(f"[ERROR] : {type(e).__name__}: {e}") + print(traceback.format_exc()) + print("Error loading model, try sdpa.") + model = VibeVoiceForConditionalGenerationInference.from_pretrained( + args.model_path, + torch_dtype=torch.bfloat16, + device_map='cuda', + attn_implementation='sdpa' + ) model.eval() model.set_ddpm_inference_steps(num_steps=10)