feat: add Apple MPS device support (dtype + attention handling) to demos (#67)

* Enhance model loading with device support and error handling

Updated device handling for model loading and added support for MPS. Improved error handling and fallback mechanisms for attention implementations.

* Improve device handling and model loading logic

Updated device argument handling to support MPS and added validation for MPS availability. Enhanced model loading logic based on the selected device type.

* fallback only when flash_attention_2 and add some comments back

---------

Co-authored-by: YaoyaoChang <cyy574006791@qq.com>
This commit is contained in:
gregory-fanous 2025-09-01 04:16:01 -05:00 committed by GitHub
parent 3074f898ca
commit 6dd49cb720
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 130 additions and 44 deletions

View file

@ -47,30 +47,67 @@ class VibeVoiceDemo:
def load_model(self):
"""Load the VibeVoice model and processor."""
print(f"Loading processor & model from {self.model_path}")
# Normalize potential 'mpx'
if self.device.lower() == "mpx":
print("Note: device 'mpx' detected, treating it as 'mps'.")
self.device = "mps"
if self.device == "mps" and not torch.backends.mps.is_available():
print("Warning: MPS not available. Falling back to CPU.")
self.device = "cpu"
print(f"Using device: {self.device}")
# Load processor
self.processor = VibeVoiceProcessor.from_pretrained(
self.model_path,
)
self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
# Decide dtype & attention
if self.device == "mps":
load_dtype = torch.float32
attn_impl_primary = "sdpa"
elif self.device == "cuda":
load_dtype = torch.bfloat16
attn_impl_primary = "flash_attention_2"
else:
load_dtype = torch.float32
attn_impl_primary = "sdpa"
print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
# Load model
try:
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
self.model_path,
torch_dtype=torch.bfloat16,
device_map='cuda',
attn_implementation='flash_attention_2' # flash_attention_2 is recommended
)
if self.device == "mps":
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
self.model_path,
torch_dtype=load_dtype,
attn_implementation=attn_impl_primary,
device_map=None,
)
self.model.to("mps")
elif self.device == "cuda":
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
self.model_path,
torch_dtype=load_dtype,
device_map="cuda",
attn_implementation=attn_impl_primary,
)
else:
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
self.model_path,
torch_dtype=load_dtype,
device_map="cpu",
attn_implementation=attn_impl_primary,
)
except Exception as e:
print(f"[ERROR] : {type(e).__name__}: {e}")
print(traceback.format_exc())
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
self.model_path,
torch_dtype=torch.bfloat16,
device_map='cuda',
attn_implementation='sdpa'
)
if attn_impl_primary == 'flash_attention_2':
print(f"[ERROR] : {type(e).__name__}: {e}")
print(traceback.format_exc())
fallback_attn = "sdpa"
print(f"Falling back to attention implementation: {fallback_attn}")
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
self.model_path,
torch_dtype=load_dtype,
device_map=(self.device if self.device in ("cuda", "cpu") else None),
attn_implementation=fallback_attn,
)
if self.device == "mps":
self.model.to("mps")
else:
raise e
self.model.eval()
# Use SDE solver by default
@ -238,6 +275,11 @@ class VibeVoiceDemo:
return_tensors="pt",
return_attention_mask=True,
)
# Move tensors to device
target_device = self.device if self.device in ("cuda", "mps") else "cpu"
for k, v in inputs.items():
if torch.is_tensor(v):
inputs[k] = v.to(target_device)
# Create audio streamer
audio_streamer = AudioStreamer(
@ -1133,8 +1175,8 @@ def parse_args():
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device for inference",
default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
help="Device for inference: cuda | mps | cpu",
)
parser.add_argument(
"--inference_steps",

View file

@ -167,8 +167,8 @@ def parse_args():
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device for tensor tests",
default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
help="Device for inference: cuda | mps | cpu",
)
parser.add_argument(
"--cfg_scale",
@ -182,6 +182,18 @@ def parse_args():
def main():
args = parse_args()
# Normalize potential 'mpx' typo to 'mps'
if args.device.lower() == "mpx":
print("Note: device 'mpx' detected, treating it as 'mps'.")
args.device = "mps"
# Validate mps availability if requested
if args.device == "mps" and not torch.backends.mps.is_available():
print("Warning: MPS not available. Falling back to CPU.")
args.device = "cpu"
print(f"Using device: {args.device}")
# Initialize voice mapper
voice_mapper = VoiceMapper()
@ -241,36 +253,62 @@ def main():
full_script = '\n'.join(scripts)
full_script = full_script.replace("", "'")
# Load processor
print(f"Loading processor & model from {args.model_path}")
processor = VibeVoiceProcessor.from_pretrained(args.model_path)
# Load model
torch_dtype, device_map, attn_implementation = torch.bfloat16, 'cuda', 'flash_attention_2'
if args.device == 'cpu':
torch_dtype, device_map, attn_implementation = torch.float32, 'cpu', 'sdpa'
print(f"Using device: {args.device}, torch_dtype: {torch_dtype}, attn_implementation: {attn_implementation}")
# Decide dtype & attention implementation
if args.device == "mps":
load_dtype = torch.float32 # MPS requires float32
attn_impl_primary = "sdpa" # flash_attention_2 not supported on MPS
elif args.device == "cuda":
load_dtype = torch.bfloat16
attn_impl_primary = "flash_attention_2"
else: # cpu
load_dtype = torch.float32
attn_impl_primary = "sdpa"
print(f"Using device: {args.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
# Load model with device-specific logic
try:
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=torch_dtype,
device_map=device_map,
attn_implementation=attn_implementation
)
if args.device == "mps":
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=load_dtype,
attn_implementation=attn_impl_primary,
device_map=None, # load then move
)
model.to("mps")
elif args.device == "cuda":
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=load_dtype,
device_map="cuda",
attn_implementation=attn_impl_primary,
)
else: # cpu
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=load_dtype,
device_map="cpu",
attn_implementation=attn_impl_primary,
)
except Exception as e:
if attn_implementation == 'flash_attention_2':
if attn_impl_primary == 'flash_attention_2':
print(f"[ERROR] : {type(e).__name__}: {e}")
print(traceback.format_exc())
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=torch_dtype,
device_map=device_map,
torch_dtype=load_dtype,
device_map=(args.device if args.device in ("cuda", "cpu") else None),
attn_implementation='sdpa'
)
if args.device == "mps":
model.to("mps")
else:
raise e
model.eval()
model.set_ddpm_inference_steps(num_steps=10)
@ -279,12 +317,19 @@ def main():
# Prepare inputs for the model
inputs = processor(
text=[full_script], # Wrap in list for batch processing
voice_samples=[voice_samples], # Wrap in list for batch processing
text=[full_script], # Wrap in list for batch processing
voice_samples=[voice_samples], # Wrap in list for batch processing
padding=True,
return_tensors="pt",
return_attention_mask=True,
)
# Move tensors to target device
target_device = args.device if args.device != "cpu" else "cpu"
for k, v in inputs.items():
if torch.is_tensor(v):
inputs[k] = v.to(target_device)
print(f"Starting generation with cfg_scale: {args.cfg_scale}")
# Generate audio
@ -294,7 +339,6 @@ def main():
max_new_tokens=None,
cfg_scale=args.cfg_scale,
tokenizer=processor.tokenizer,
# generation_config={'do_sample': False, 'temperature': 0.95, 'top_p': 0.95, 'top_k': 0},
generation_config={'do_sample': False},
verbose=True,
)
@ -323,13 +367,13 @@ def main():
print(f"Generated tokens: {generated_tokens}")
print(f"Total tokens: {output_tokens}")
# Save output
# Save output (processor handles device internally)
txt_filename = os.path.splitext(os.path.basename(args.txt_path))[0]
output_path = os.path.join(args.output_dir, f"{txt_filename}_generated.wav")
os.makedirs(args.output_dir, exist_ok=True)
processor.save_audio(
outputs.speech_outputs[0], # First (and only) batch item
outputs.speech_outputs[0], # First (and only) batch item
output_path=output_path,
)
print(f"Saved output to {output_path}")