diff --git a/demo/gradio_demo.py b/demo/gradio_demo.py
index de3410c..070dba2 100644
--- a/demo/gradio_demo.py
+++ b/demo/gradio_demo.py
@@ -47,30 +47,67 @@ class VibeVoiceDemo:
     def load_model(self):
         """Load the VibeVoice model and processor."""
         print(f"Loading processor & model from {self.model_path}")
-
+        # Normalize potential 'mpx'
+        if self.device.lower() == "mpx":
+            print("Note: device 'mpx' detected, treating it as 'mps'.")
+            self.device = "mps"
+        if self.device == "mps" and not torch.backends.mps.is_available():
+            print("Warning: MPS not available. Falling back to CPU.")
+            self.device = "cpu"
+        print(f"Using device: {self.device}")
         # Load processor
-        self.processor = VibeVoiceProcessor.from_pretrained(
-            self.model_path,
-        )
-
+        self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
+        # Decide dtype & attention
+        if self.device == "mps":
+            load_dtype = torch.float32
+            attn_impl_primary = "sdpa"
+        elif self.device == "cuda":
+            load_dtype = torch.bfloat16
+            attn_impl_primary = "flash_attention_2"
+        else:
+            load_dtype = torch.float32
+            attn_impl_primary = "sdpa"
+        print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
         # Load model
         try:
-            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-                self.model_path,
-                torch_dtype=torch.bfloat16,
-                device_map='cuda',
-                attn_implementation='flash_attention_2' # flash_attention_2 is recommended
-            )
+            if self.device == "mps":
+                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                    self.model_path,
+                    torch_dtype=load_dtype,
+                    attn_implementation=attn_impl_primary,
+                    device_map=None,
+                )
+                self.model.to("mps")
+            elif self.device == "cuda":
+                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                    self.model_path,
+                    torch_dtype=load_dtype,
+                    device_map="cuda",
+                    attn_implementation=attn_impl_primary,
+                )
+            else:
+                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                    self.model_path,
+                    torch_dtype=load_dtype,
+                    device_map="cpu",
+                    attn_implementation=attn_impl_primary,
+                )
         except Exception as e:
-            print(f"[ERROR] : {type(e).__name__}: {e}")
-            print(traceback.format_exc())
-            print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
-            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-                self.model_path,
-                torch_dtype=torch.bfloat16,
-                device_map='cuda',
-                attn_implementation='sdpa'
-            )
+            if attn_impl_primary == 'flash_attention_2':
+                print(f"[ERROR] : {type(e).__name__}: {e}")
+                print(traceback.format_exc())
+                fallback_attn = "sdpa"
+                print(f"Falling back to attention implementation: {fallback_attn}")
+                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                    self.model_path,
+                    torch_dtype=load_dtype,
+                    device_map=(self.device if self.device in ("cuda", "cpu") else None),
+                    attn_implementation=fallback_attn,
+                )
+                if self.device == "mps":
+                    self.model.to("mps")
+            else:
+                raise e
         self.model.eval()

         # Use SDE solver by default
@@ -238,6 +275,11 @@
             return_tensors="pt",
             return_attention_mask=True,
         )
+        # Move tensors to device
+        target_device = self.device if self.device in ("cuda", "mps") else "cpu"
+        for k, v in inputs.items():
+            if torch.is_tensor(v):
+                inputs[k] = v.to(target_device)

         # Create audio streamer
         audio_streamer = AudioStreamer(
@@ -1133,8 +1175,8 @@
     parser.add_argument(
         "--device",
         type=str,
-        default="cuda" if torch.cuda.is_available() else "cpu",
-        help="Device for inference",
+        default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
+        help="Device for inference: cuda | mps | cpu",
     )
     parser.add_argument(
         "--inference_steps",
diff --git a/demo/inference_from_file.py b/demo/inference_from_file.py
index 0c5b47e..238ea70 100644
--- a/demo/inference_from_file.py
+++ b/demo/inference_from_file.py
@@ -167,8 +167,8 @@ def parse_args():
     parser.add_argument(
         "--device",
         type=str,
-        default="cuda" if torch.cuda.is_available() else "cpu",
-        help="Device for tensor tests",
+        default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
+        help="Device for inference: cuda | mps | cpu",
    )
     parser.add_argument(
         "--cfg_scale",
@@ -182,6 +182,18 @@
 def main():
     args = parse_args()
+    # Normalize potential 'mpx' typo to 'mps'
+    if args.device.lower() == "mpx":
+        print("Note: device 'mpx' detected, treating it as 'mps'.")
+        args.device = "mps"
+
+    # Validate mps availability if requested
+    if args.device == "mps" and not torch.backends.mps.is_available():
+        print("Warning: MPS not available. Falling back to CPU.")
Falling back to CPU.") + args.device = "cpu" + + print(f"Using device: {args.device}") + # Initialize voice mapper voice_mapper = VoiceMapper() @@ -241,36 +253,62 @@ def main(): full_script = '\n'.join(scripts) full_script = full_script.replace("’", "'") - # Load processor print(f"Loading processor & model from {args.model_path}") processor = VibeVoiceProcessor.from_pretrained(args.model_path) - # Load model - torch_dtype, device_map, attn_implementation = torch.bfloat16, 'cuda', 'flash_attention_2' - if args.device == 'cpu': - torch_dtype, device_map, attn_implementation = torch.float32, 'cpu', 'sdpa' - print(f"Using device: {args.device}, torch_dtype: {torch_dtype}, attn_implementation: {attn_implementation}") + + # Decide dtype & attention implementation + if args.device == "mps": + load_dtype = torch.float32 # MPS requires float32 + attn_impl_primary = "sdpa" # flash_attention_2 not supported on MPS + elif args.device == "cuda": + load_dtype = torch.bfloat16 + attn_impl_primary = "flash_attention_2" + else: # cpu + load_dtype = torch.float32 + attn_impl_primary = "sdpa" + print(f"Using device: {args.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}") + # Load model with device-specific logic try: - model = VibeVoiceForConditionalGenerationInference.from_pretrained( - args.model_path, - torch_dtype=torch_dtype, - device_map=device_map, - attn_implementation=attn_implementation - ) + if args.device == "mps": + model = VibeVoiceForConditionalGenerationInference.from_pretrained( + args.model_path, + torch_dtype=load_dtype, + attn_implementation=attn_impl_primary, + device_map=None, # load then move + ) + model.to("mps") + elif args.device == "cuda": + model = VibeVoiceForConditionalGenerationInference.from_pretrained( + args.model_path, + torch_dtype=load_dtype, + device_map="cuda", + attn_implementation=attn_impl_primary, + ) + else: # cpu + model = VibeVoiceForConditionalGenerationInference.from_pretrained( + args.model_path, + torch_dtype=load_dtype, + device_map="cpu", + attn_implementation=attn_impl_primary, + ) except Exception as e: - if attn_implementation == 'flash_attention_2': + if attn_impl_primary == 'flash_attention_2': print(f"[ERROR] : {type(e).__name__}: {e}") print(traceback.format_exc()) print("Error loading the model. Trying to use SDPA. 
However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.") model = VibeVoiceForConditionalGenerationInference.from_pretrained( args.model_path, - torch_dtype=torch_dtype, - device_map=device_map, + torch_dtype=load_dtype, + device_map=(args.device if args.device in ("cuda", "cpu") else None), attn_implementation='sdpa' ) + if args.device == "mps": + model.to("mps") else: raise e + model.eval() model.set_ddpm_inference_steps(num_steps=10) @@ -279,12 +317,19 @@ def main(): # Prepare inputs for the model inputs = processor( - text=[full_script], # Wrap in list for batch processing - voice_samples=[voice_samples], # Wrap in list for batch processing + text=[full_script], # Wrap in list for batch processing + voice_samples=[voice_samples], # Wrap in list for batch processing padding=True, return_tensors="pt", return_attention_mask=True, ) + + # Move tensors to target device + target_device = args.device if args.device != "cpu" else "cpu" + for k, v in inputs.items(): + if torch.is_tensor(v): + inputs[k] = v.to(target_device) + print(f"Starting generation with cfg_scale: {args.cfg_scale}") # Generate audio @@ -294,7 +339,6 @@ def main(): max_new_tokens=None, cfg_scale=args.cfg_scale, tokenizer=processor.tokenizer, - # generation_config={'do_sample': False, 'temperature': 0.95, 'top_p': 0.95, 'top_k': 0}, generation_config={'do_sample': False}, verbose=True, ) @@ -323,13 +367,13 @@ def main(): print(f"Generated tokens: {generated_tokens}") print(f"Total tokens: {output_tokens}") - # Save output + # Save output (processor handles device internally) txt_filename = os.path.splitext(os.path.basename(args.txt_path))[0] output_path = os.path.join(args.output_dir, f"{txt_filename}_generated.wav") os.makedirs(args.output_dir, exist_ok=True) processor.save_audio( - outputs.speech_outputs[0], # First (and only) batch item + outputs.speech_outputs[0], # First (and only) batch item output_path=output_path, ) print(f"Saved output to {output_path}")
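Review note (commentary, not part of the diff): the 'mpx' normalization, the MPS availability check, and the dtype/attention selection are now duplicated verbatim across demo/gradio_demo.py and demo/inference_from_file.py. If a follow-up wants to deduplicate, a minimal sketch of a shared helper is below; the module path demo/device_utils.py and the function name resolve_device_config are hypothetical, not part of this patch.

    # demo/device_utils.py -- hypothetical shared helper, not part of this patch
    import torch

    def resolve_device_config(device: str):
        """Normalize a device string; return (device, dtype, attn_implementation).

        Mirrors the logic this patch adds to both demo scripts: 'mpx' is
        treated as a typo for 'mps', 'mps' falls back to 'cpu' when MPS is
        unavailable, CUDA uses bfloat16 + flash_attention_2, and MPS/CPU
        use float32 + SDPA.
        """
        device = device.lower()
        if device == "mpx":
            print("Note: device 'mpx' detected, treating it as 'mps'.")
            device = "mps"
        if device == "mps" and not torch.backends.mps.is_available():
            print("Warning: MPS not available. Falling back to CPU.")
            device = "cpu"
        if device == "cuda":
            return device, torch.bfloat16, "flash_attention_2"
        return device, torch.float32, "sdpa"

Both load paths (and the --device argparse defaults) could then call this single function, so the three device rules cannot drift apart.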