Update attention fallback option from eager to SDPA (--use_eager -> --use_sdpa)

This commit is contained in:
YaoyaoChang 2025-08-27 18:20:42 -07:00
parent 2b75b745a4
commit 8ed05ba7de

View file

@@ -176,9 +176,9 @@ def parse_args():
help="CFG (Classifier-Free Guidance) scale for generation (default: 1.3)", help="CFG (Classifier-Free Guidance) scale for generation (default: 1.3)",
) )
parser.add_argument( parser.add_argument(
"--use_eager", "--use_sdpa",
action="store_true", action="store_true",
help="Use eager attention mode instead of flash_attention_2", help="Use sdpa attention mode instead of flash_attention_2",
) )
return parser.parse_args() return parser.parse_args()
@@ -249,12 +249,12 @@ def main():
processor = VibeVoiceProcessor.from_pretrained(args.model_path) processor = VibeVoiceProcessor.from_pretrained(args.model_path)
# Load model # Load model
attn_implementation = "flash_attention_2" if not args.use_eager else "eager" attn_implementation = "flash_attention_2" if not args.use_sdpa else "sdpa"
model = VibeVoiceForConditionalGenerationInference.from_pretrained( model = VibeVoiceForConditionalGenerationInference.from_pretrained(
args.model_path, args.model_path,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device_map='cuda', device_map='cuda',
attn_implementation=attn_implementation # flash_attention_2 is recommended, eager may lead to lower audio quality attn_implementation=attn_implementation # flash_attention_2 is recommended
) )
model.eval() model.eval()