support CPU
This commit is contained in:
parent
086cd6aa6c
commit
3074f898ca
1 changed file with 19 additions and 12 deletions
|
@@ -246,23 +246,30 @@ def main():
|
||||||
processor = VibeVoiceProcessor.from_pretrained(args.model_path)
|
processor = VibeVoiceProcessor.from_pretrained(args.model_path)
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
|
torch_dtype, device_map, attn_implementation = torch.bfloat16, 'cuda', 'flash_attention_2'
|
||||||
|
if args.device == 'cpu':
|
||||||
|
torch_dtype, device_map, attn_implementation = torch.float32, 'cpu', 'sdpa'
|
||||||
|
print(f"Using device: {args.device}, torch_dtype: {torch_dtype}, attn_implementation: {attn_implementation}")
|
||||||
try:
|
try:
|
||||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||||
args.model_path,
|
args.model_path,
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch_dtype,
|
||||||
device_map='cuda',
|
device_map=device_map,
|
||||||
attn_implementation='flash_attention_2' # flash_attention_2 is recommended
|
attn_implementation=attn_implementation
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[ERROR] : {type(e).__name__}: {e}")
|
if attn_implementation == 'flash_attention_2':
|
||||||
print(traceback.format_exc())
|
print(f"[ERROR] : {type(e).__name__}: {e}")
|
||||||
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
|
print(traceback.format_exc())
|
||||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
|
||||||
args.model_path,
|
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||||
torch_dtype=torch.bfloat16,
|
args.model_path,
|
||||||
device_map='cuda',
|
torch_dtype=torch_dtype,
|
||||||
attn_implementation='sdpa'
|
device_map=device_map,
|
||||||
)
|
attn_implementation='sdpa'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
model.set_ddpm_inference_steps(num_steps=10)
|
model.set_ddpm_inference_steps(num_steps=10)
|
||||||
|
|
Loading…
Reference in a new issue