gradio support sdpa
This commit is contained in:
parent
999020e7c4
commit
174e53fc04
1 changed files with 17 additions and 6 deletions
|
@ -54,12 +54,23 @@ class VibeVoiceDemo:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
try:
|
||||||
self.model_path,
|
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||||
torch_dtype=torch.bfloat16,
|
self.model_path,
|
||||||
device_map='cuda',
|
torch_dtype=torch.bfloat16,
|
||||||
attn_implementation="flash_attention_2",
|
device_map='cuda',
|
||||||
)
|
attn_implementation='flash_attention_2' # flash_attention_2 is recommended
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] : {type(e).__name__}: {e}")
|
||||||
|
print(traceback.format_exc())
|
||||||
|
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
|
||||||
|
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||||
|
self.model_path,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map='cuda',
|
||||||
|
attn_implementation='sdpa'
|
||||||
|
)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
# Use SDE solver by default
|
# Use SDE solver by default
|
||||||
|
|
Loading…
Reference in a new issue