update to sdpa
This commit is contained in:
parent
2b75b745a4
commit
8ed05ba7de
1 changed files with 4 additions and 4 deletions
|
@@ -176,9 +176,9 @@ def parse_args():
|
|||
help="CFG (Classifier-Free Guidance) scale for generation (default: 1.3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_eager",
|
||||
"--use_sdpa",
|
||||
action="store_true",
|
||||
help="Use eager attention mode instead of flash_attention_2",
|
||||
help="Use sdpa attention mode instead of flash_attention_2",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
@@ -249,12 +249,12 @@ def main():
|
|||
processor = VibeVoiceProcessor.from_pretrained(args.model_path)
|
||||
|
||||
# Load model
|
||||
attn_implementation = "flash_attention_2" if not args.use_eager else "eager"
|
||||
attn_implementation = "flash_attention_2" if not args.use_sdpa else "sdpa"
|
||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map='cuda',
|
||||
attn_implementation=attn_implementation # flash_attention_2 is recommended, eager may lead to lower audio quality
|
||||
attn_implementation=attn_implementation # flash_attention_2 is recommended
|
||||
)
|
||||
|
||||
model.eval()
|
||||
|
|
Loading…
Reference in a new issue