Fall back to sdpa attention if loading the model with flash_attention_2 raises an error
This commit is contained in:
parent
64241e29ca
commit
ac0104c65e
1 changed file with 18 additions and 12 deletions
|
@ -1,6 +1,7 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import traceback
|
||||||
from typing import List, Tuple, Union, Dict, Any
|
from typing import List, Tuple, Union, Dict, Any
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
|
@ -175,11 +176,6 @@ def parse_args():
|
||||||
default=1.3,
|
default=1.3,
|
||||||
help="CFG (Classifier-Free Guidance) scale for generation (default: 1.3)",
|
help="CFG (Classifier-Free Guidance) scale for generation (default: 1.3)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--use_sdpa",
|
|
||||||
action="store_true",
|
|
||||||
help="Use sdpa attention mode instead of flash_attention_2",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
@ -249,13 +245,23 @@ def main():
|
||||||
processor = VibeVoiceProcessor.from_pretrained(args.model_path)
|
processor = VibeVoiceProcessor.from_pretrained(args.model_path)
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
attn_implementation = "flash_attention_2" if not args.use_sdpa else "sdpa"
|
try:
|
||||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||||
args.model_path,
|
args.model_path,
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
device_map='cuda',
|
device_map='cuda',
|
||||||
attn_implementation=attn_implementation # flash_attention_2 is recommended
|
attn_implementation='flash_attention_2' # flash_attention_2 is recommended
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] : {type(e).__name__}: {e}")
|
||||||
|
print(traceback.format_exc())
|
||||||
|
print("Error loading model, try sdpa.")
|
||||||
|
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||||
|
args.model_path,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map='cuda',
|
||||||
|
attn_implementation='sdpa'
|
||||||
|
)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
model.set_ddpm_inference_steps(num_steps=10)
|
model.set_ddpm_inference_steps(num_steps=10)
|
||||||
|
|
Loading…
Reference in a new issue