From 3074f898ca5d5052403594dbf0006603260dafee Mon Sep 17 00:00:00 2001
From: YaoyaoChang
Date: Mon, 1 Sep 2025 16:32:50 +0800
Subject: [PATCH] support CPU

---
 demo/inference_from_file.py | 31 +++++++++++++++++++------------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/demo/inference_from_file.py b/demo/inference_from_file.py
index 73fbce8..0c5b47e 100644
--- a/demo/inference_from_file.py
+++ b/demo/inference_from_file.py
@@ -246,23 +246,30 @@ def main():
     processor = VibeVoiceProcessor.from_pretrained(args.model_path)
 
     # Load model
+    torch_dtype, device_map, attn_implementation = torch.bfloat16, 'cuda', 'flash_attention_2'
+    if args.device == 'cpu':
+        torch_dtype, device_map, attn_implementation = torch.float32, 'cpu', 'sdpa'
+    print(f"Using device: {args.device}, torch_dtype: {torch_dtype}, attn_implementation: {attn_implementation}")
     try:
         model = VibeVoiceForConditionalGenerationInference.from_pretrained(
             args.model_path,
-            torch_dtype=torch.bfloat16,
-            device_map='cuda',
-            attn_implementation='flash_attention_2' # flash_attention_2 is recommended
+            torch_dtype=torch_dtype,
+            device_map=device_map,
+            attn_implementation=attn_implementation
         )
     except Exception as e:
-        print(f"[ERROR] : {type(e).__name__}: {e}")
-        print(traceback.format_exc())
-        print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
-        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-            args.model_path,
-            torch_dtype=torch.bfloat16,
-            device_map='cuda',
-            attn_implementation='sdpa'
-        )
+        if attn_implementation == 'flash_attention_2':
+            print(f"[ERROR] : {type(e).__name__}: {e}")
+            print(traceback.format_exc())
+            print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
+            model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                args.model_path,
+                torch_dtype=torch_dtype,
+                device_map=device_map,
+                attn_implementation='sdpa'
+            )
+        else:
+            raise e
 
     model.eval()
     model.set_ddpm_inference_steps(num_steps=10)