gradio support sdpa

This commit is contained in:
YaoyaoChang 2025-08-30 08:22:32 -07:00
parent 999020e7c4
commit 174e53fc04

View file

@ -54,12 +54,23 @@ class VibeVoiceDemo:
)
# Load model
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
self.model_path,
torch_dtype=torch.bfloat16,
device_map='cuda',
attn_implementation="flash_attention_2",
)
try:
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
self.model_path,
torch_dtype=torch.bfloat16,
device_map='cuda',
attn_implementation='flash_attention_2' # flash_attention_2 is recommended
)
except Exception as e:
print(f"[ERROR] : {type(e).__name__}: {e}")
print(traceback.format_exc())
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
self.model_path,
torch_dtype=torch.bfloat16,
device_map='cuda',
attn_implementation='sdpa'
)
self.model.eval()
# Use SDE solver by default