try sdpa if error in flash_attention_2

YaoyaoChang 2025-08-27 18:42:29 -07:00
parent 64241e29ca
commit ac0104c65e

@@ -1,6 +1,7 @@
 import argparse
 import os
 import re
+import traceback
 from typing import List, Tuple, Union, Dict, Any
 import time
 import torch
@@ -175,11 +176,6 @@ def parse_args():
         default=1.3,
         help="CFG (Classifier-Free Guidance) scale for generation (default: 1.3)",
     )
-    parser.add_argument(
-        "--use_sdpa",
-        action="store_true",
-        help="Use sdpa attention mode instead of flash_attention_2",
-    )
     return parser.parse_args()
@@ -249,13 +245,23 @@ def main():
     processor = VibeVoiceProcessor.from_pretrained(args.model_path)
 
     # Load model
-    attn_implementation = "flash_attention_2" if not args.use_sdpa else "sdpa"
-    model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-        args.model_path,
-        torch_dtype=torch.bfloat16,
-        device_map='cuda',
-        attn_implementation=attn_implementation # flash_attention_2 is recommended
-    )
+    try:
+        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+            args.model_path,
+            torch_dtype=torch.bfloat16,
+            device_map='cuda',
+            attn_implementation='flash_attention_2' # flash_attention_2 is recommended
+        )
+    except Exception as e:
+        print(f"[ERROR] : {type(e).__name__}: {e}")
+        print(traceback.format_exc())
+        print("Error loading model, try sdpa.")
+        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+            args.model_path,
+            torch_dtype=torch.bfloat16,
+            device_map='cuda',
+            attn_implementation='sdpa'
+        )
     model.eval()
     model.set_ddpm_inference_steps(num_steps=10)
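
For reference, the fallback this commit introduces boils down to the pattern below. This is a minimal sketch assuming a transformers-style from_pretrained() that accepts attn_implementation; the load_with_attention_fallback helper and the model_cls parameter are illustrative, not part of the repository.

import traceback

import torch


def load_with_attention_fallback(model_cls, model_path):
    """Prefer flash_attention_2; fall back to sdpa if the first load fails."""
    try:
        # flash_attention_2 requires the flash-attn package and a supported GPU.
        return model_cls.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map='cuda',
            attn_implementation='flash_attention_2',
        )
    except Exception as e:
        # Typical failures: flash-attn not installed, or unsupported hardware/dtype.
        print(f"[ERROR] : {type(e).__name__}: {e}")
        print(traceback.format_exc())
        print("Error loading model, try sdpa.")
        # sdpa uses torch.nn.functional.scaled_dot_product_attention and needs no extra install.
        return model_cls.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map='cuda',
            attn_implementation='sdpa',
        )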