update to sdpa

This commit is contained in:
YaoyaoChang 2025-08-27 18:20:42 -07:00
parent 2b75b745a4
commit 8ed05ba7de

View file

@ -176,9 +176,9 @@ def parse_args():
help="CFG (Classifier-Free Guidance) scale for generation (default: 1.3)",
)
parser.add_argument(
"--use_eager",
"--use_sdpa",
action="store_true",
help="Use eager attention mode instead of flash_attention_2",
help="Use sdpa attention mode instead of flash_attention_2",
)
return parser.parse_args()
@ -249,12 +249,12 @@ def main():
processor = VibeVoiceProcessor.from_pretrained(args.model_path)
# Load model
attn_implementation = "flash_attention_2" if not args.use_eager else "eager"
attn_implementation = "flash_attention_2" if not args.use_sdpa else "sdpa"
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=torch.bfloat16,
device_map='cuda',
attn_implementation=attn_implementation # flash_attention_2 is recommended, eager may lead to lower audio quality
attn_implementation=attn_implementation # flash_attention_2 is recommended
)
model.eval()