support CPU

This commit is contained in:
YaoyaoChang 2025-09-01 16:32:50 +08:00
parent 086cd6aa6c
commit 3074f898ca

View file

@ -246,23 +246,30 @@ def main():
processor = VibeVoiceProcessor.from_pretrained(args.model_path)
# Load model
torch_dtype, device_map, attn_implementation = torch.bfloat16, 'cuda', 'flash_attention_2'
if args.device == 'cpu':
torch_dtype, device_map, attn_implementation = torch.float32, 'cpu', 'sdpa'
print(f"Using device: {args.device}, torch_dtype: {torch_dtype}, attn_implementation: {attn_implementation}")
try:
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=torch.bfloat16,
device_map='cuda',
attn_implementation='flash_attention_2' # flash_attention_2 is recommended
torch_dtype=torch_dtype,
device_map=device_map,
attn_implementation=attn_implementation
)
except Exception as e:
print(f"[ERROR] : {type(e).__name__}: {e}")
print(traceback.format_exc())
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=torch.bfloat16,
device_map='cuda',
attn_implementation='sdpa'
)
if attn_implementation == 'flash_attention_2':
print(f"[ERROR] : {type(e).__name__}: {e}")
print(traceback.format_exc())
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=torch_dtype,
device_map=device_map,
attn_implementation='sdpa'
)
else:
raise e
model.eval()
model.set_ddpm_inference_steps(num_steps=10)