Fall back to sdpa if loading the model with flash_attention_2 fails

This commit is contained in:
YaoyaoChang 2025-08-27 18:42:29 -07:00
parent 64241e29ca
commit ac0104c65e

View file

@@ -1,6 +1,7 @@
import argparse import argparse
import os import os
import re import re
import traceback
from typing import List, Tuple, Union, Dict, Any from typing import List, Tuple, Union, Dict, Any
import time import time
import torch import torch
@@ -175,11 +176,6 @@ def parse_args():
default=1.3, default=1.3,
help="CFG (Classifier-Free Guidance) scale for generation (default: 1.3)", help="CFG (Classifier-Free Guidance) scale for generation (default: 1.3)",
) )
parser.add_argument(
"--use_sdpa",
action="store_true",
help="Use sdpa attention mode instead of flash_attention_2",
)
return parser.parse_args() return parser.parse_args()
@@ -249,13 +245,23 @@ def main():
processor = VibeVoiceProcessor.from_pretrained(args.model_path) processor = VibeVoiceProcessor.from_pretrained(args.model_path)
# Load model # Load model
attn_implementation = "flash_attention_2" if not args.use_sdpa else "sdpa" try:
model = VibeVoiceForConditionalGenerationInference.from_pretrained( model = VibeVoiceForConditionalGenerationInference.from_pretrained(
args.model_path, args.model_path,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map='cuda', device_map='cuda',
attn_implementation=attn_implementation # flash_attention_2 is recommended attn_implementation='flash_attention_2' # flash_attention_2 is recommended
) )
except Exception as e:
print(f"[ERROR] : {type(e).__name__}: {e}")
print(traceback.format_exc())
print("Error loading model, try sdpa.")
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=torch.bfloat16,
device_map='cuda',
attn_implementation='sdpa'
)
model.eval() model.eval()
model.set_ddpm_inference_steps(num_steps=10) model.set_ddpm_inference_steps(num_steps=10)