support CPU

YaoyaoChang 2025-09-01 16:32:50 +08:00
parent 086cd6aa6c
commit 3074f898ca

@@ -246,23 +246,30 @@ def main():
     processor = VibeVoiceProcessor.from_pretrained(args.model_path)
 
     # Load model
+    torch_dtype, device_map, attn_implementation = torch.bfloat16, 'cuda', 'flash_attention_2'
+    if args.device == 'cpu':
+        torch_dtype, device_map, attn_implementation = torch.float32, 'cpu', 'sdpa'
+    print(f"Using device: {args.device}, torch_dtype: {torch_dtype}, attn_implementation: {attn_implementation}")
     try:
         model = VibeVoiceForConditionalGenerationInference.from_pretrained(
             args.model_path,
-            torch_dtype=torch.bfloat16,
-            device_map='cuda',
-            attn_implementation='flash_attention_2' # flash_attention_2 is recommended
+            torch_dtype=torch_dtype,
+            device_map=device_map,
+            attn_implementation=attn_implementation
         )
     except Exception as e:
-        print(f"[ERROR] : {type(e).__name__}: {e}")
-        print(traceback.format_exc())
-        print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
-        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-            args.model_path,
-            torch_dtype=torch.bfloat16,
-            device_map='cuda',
-            attn_implementation='sdpa'
-        )
+        if attn_implementation == 'flash_attention_2':
+            print(f"[ERROR] : {type(e).__name__}: {e}")
+            print(traceback.format_exc())
+            print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
+            model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                args.model_path,
+                torch_dtype=torch_dtype,
+                device_map=device_map,
+                attn_implementation='sdpa'
+            )
+        else:
+            raise e
 
     model.eval()
     model.set_ddpm_inference_steps(num_steps=10)
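
Note: the new branch keys off args.device, which is not defined in this hunk. Below is a minimal sketch of how such a flag might be declared with argparse; the option name --device, its default, the choices list, and the parse_args helper are assumptions for illustration, not part of this commit.

# Hypothetical sketch of the CLI argument that args.device above implies.
# Names and defaults are assumptions; they are not shown in this diff.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="VibeVoice inference")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path or Hugging Face id of the VibeVoice checkpoint")
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "cpu"],
                        help="cuda selects bfloat16 + flash_attention_2; "
                             "cpu selects float32 + sdpa")
    return parser.parse_args()

With an argument like this, passing --device cpu at launch would route model loading through the float32/sdpa path added in this commit.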