feat: add Apple MPS device support (dtype + attention handling) to demos (#67)
* Enhance model loading with device support and error handling Updated device handling for model loading and added support for MPS. Improved error handling and fallback mechanisms for attention implementations. * Improve device handling and model loading logic Updated device argument handling to support MPS and added validation for MPS availability. Enhanced model loading logic based on the selected device type. * fallback only when flash_attention_2 and add some comments back --------- Co-authored-by: YaoyaoChang <cyy574006791@qq.com>
This commit is contained in:
parent
3074f898ca
commit
6dd49cb720
2 changed files with 130 additions and 44 deletions
|
@ -47,30 +47,67 @@ class VibeVoiceDemo:
|
|||
def load_model(self):
|
||||
"""Load the VibeVoice model and processor."""
|
||||
print(f"Loading processor & model from {self.model_path}")
|
||||
|
||||
# Normalize potential 'mpx'
|
||||
if self.device.lower() == "mpx":
|
||||
print("Note: device 'mpx' detected, treating it as 'mps'.")
|
||||
self.device = "mps"
|
||||
if self.device == "mps" and not torch.backends.mps.is_available():
|
||||
print("Warning: MPS not available. Falling back to CPU.")
|
||||
self.device = "cpu"
|
||||
print(f"Using device: {self.device}")
|
||||
# Load processor
|
||||
self.processor = VibeVoiceProcessor.from_pretrained(
|
||||
self.model_path,
|
||||
)
|
||||
|
||||
self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
|
||||
# Decide dtype & attention
|
||||
if self.device == "mps":
|
||||
load_dtype = torch.float32
|
||||
attn_impl_primary = "sdpa"
|
||||
elif self.device == "cuda":
|
||||
load_dtype = torch.bfloat16
|
||||
attn_impl_primary = "flash_attention_2"
|
||||
else:
|
||||
load_dtype = torch.float32
|
||||
attn_impl_primary = "sdpa"
|
||||
print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
|
||||
# Load model
|
||||
try:
|
||||
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map='cuda',
|
||||
attn_implementation='flash_attention_2' # flash_attention_2 is recommended
|
||||
)
|
||||
if self.device == "mps":
|
||||
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=load_dtype,
|
||||
attn_implementation=attn_impl_primary,
|
||||
device_map=None,
|
||||
)
|
||||
self.model.to("mps")
|
||||
elif self.device == "cuda":
|
||||
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=load_dtype,
|
||||
device_map="cuda",
|
||||
attn_implementation=attn_impl_primary,
|
||||
)
|
||||
else:
|
||||
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=load_dtype,
|
||||
device_map="cpu",
|
||||
attn_implementation=attn_impl_primary,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] : {type(e).__name__}: {e}")
|
||||
print(traceback.format_exc())
|
||||
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
|
||||
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map='cuda',
|
||||
attn_implementation='sdpa'
|
||||
)
|
||||
if attn_impl_primary == 'flash_attention_2':
|
||||
print(f"[ERROR] : {type(e).__name__}: {e}")
|
||||
print(traceback.format_exc())
|
||||
fallback_attn = "sdpa"
|
||||
print(f"Falling back to attention implementation: {fallback_attn}")
|
||||
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=load_dtype,
|
||||
device_map=(self.device if self.device in ("cuda", "cpu") else None),
|
||||
attn_implementation=fallback_attn,
|
||||
)
|
||||
if self.device == "mps":
|
||||
self.model.to("mps")
|
||||
else:
|
||||
raise e
|
||||
self.model.eval()
|
||||
|
||||
# Use SDE solver by default
|
||||
|
@ -238,6 +275,11 @@ class VibeVoiceDemo:
|
|||
return_tensors="pt",
|
||||
return_attention_mask=True,
|
||||
)
|
||||
# Move tensors to device
|
||||
target_device = self.device if self.device in ("cuda", "mps") else "cpu"
|
||||
for k, v in inputs.items():
|
||||
if torch.is_tensor(v):
|
||||
inputs[k] = v.to(target_device)
|
||||
|
||||
# Create audio streamer
|
||||
audio_streamer = AudioStreamer(
|
||||
|
@ -1133,8 +1175,8 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Device for inference",
|
||||
default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
|
||||
help="Device for inference: cuda | mps | cpu",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--inference_steps",
|
||||
|
|
|
@ -167,8 +167,8 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Device for tensor tests",
|
||||
default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
|
||||
help="Device for inference: cuda | mps | cpu",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cfg_scale",
|
||||
|
@ -182,6 +182,18 @@ def parse_args():
|
|||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Normalize potential 'mpx' typo to 'mps'
|
||||
if args.device.lower() == "mpx":
|
||||
print("Note: device 'mpx' detected, treating it as 'mps'.")
|
||||
args.device = "mps"
|
||||
|
||||
# Validate mps availability if requested
|
||||
if args.device == "mps" and not torch.backends.mps.is_available():
|
||||
print("Warning: MPS not available. Falling back to CPU.")
|
||||
args.device = "cpu"
|
||||
|
||||
print(f"Using device: {args.device}")
|
||||
|
||||
# Initialize voice mapper
|
||||
voice_mapper = VoiceMapper()
|
||||
|
||||
|
@ -241,36 +253,62 @@ def main():
|
|||
full_script = '\n'.join(scripts)
|
||||
full_script = full_script.replace("’", "'")
|
||||
|
||||
# Load processor
|
||||
print(f"Loading processor & model from {args.model_path}")
|
||||
processor = VibeVoiceProcessor.from_pretrained(args.model_path)
|
||||
|
||||
# Load model
|
||||
torch_dtype, device_map, attn_implementation = torch.bfloat16, 'cuda', 'flash_attention_2'
|
||||
if args.device == 'cpu':
|
||||
torch_dtype, device_map, attn_implementation = torch.float32, 'cpu', 'sdpa'
|
||||
print(f"Using device: {args.device}, torch_dtype: {torch_dtype}, attn_implementation: {attn_implementation}")
|
||||
|
||||
# Decide dtype & attention implementation
|
||||
if args.device == "mps":
|
||||
load_dtype = torch.float32 # MPS requires float32
|
||||
attn_impl_primary = "sdpa" # flash_attention_2 not supported on MPS
|
||||
elif args.device == "cuda":
|
||||
load_dtype = torch.bfloat16
|
||||
attn_impl_primary = "flash_attention_2"
|
||||
else: # cpu
|
||||
load_dtype = torch.float32
|
||||
attn_impl_primary = "sdpa"
|
||||
print(f"Using device: {args.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
|
||||
# Load model with device-specific logic
|
||||
try:
|
||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
attn_implementation=attn_implementation
|
||||
)
|
||||
if args.device == "mps":
|
||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=load_dtype,
|
||||
attn_implementation=attn_impl_primary,
|
||||
device_map=None, # load then move
|
||||
)
|
||||
model.to("mps")
|
||||
elif args.device == "cuda":
|
||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=load_dtype,
|
||||
device_map="cuda",
|
||||
attn_implementation=attn_impl_primary,
|
||||
)
|
||||
else: # cpu
|
||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=load_dtype,
|
||||
device_map="cpu",
|
||||
attn_implementation=attn_impl_primary,
|
||||
)
|
||||
except Exception as e:
|
||||
if attn_implementation == 'flash_attention_2':
|
||||
if attn_impl_primary == 'flash_attention_2':
|
||||
print(f"[ERROR] : {type(e).__name__}: {e}")
|
||||
print(traceback.format_exc())
|
||||
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
|
||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
torch_dtype=load_dtype,
|
||||
device_map=(args.device if args.device in ("cuda", "cpu") else None),
|
||||
attn_implementation='sdpa'
|
||||
)
|
||||
if args.device == "mps":
|
||||
model.to("mps")
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
model.eval()
|
||||
model.set_ddpm_inference_steps(num_steps=10)
|
||||
|
||||
|
@ -279,12 +317,19 @@ def main():
|
|||
|
||||
# Prepare inputs for the model
|
||||
inputs = processor(
|
||||
text=[full_script], # Wrap in list for batch processing
|
||||
voice_samples=[voice_samples], # Wrap in list for batch processing
|
||||
text=[full_script], # Wrap in list for batch processing
|
||||
voice_samples=[voice_samples], # Wrap in list for batch processing
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
return_attention_mask=True,
|
||||
)
|
||||
|
||||
# Move tensors to target device
|
||||
target_device = args.device if args.device != "cpu" else "cpu"
|
||||
for k, v in inputs.items():
|
||||
if torch.is_tensor(v):
|
||||
inputs[k] = v.to(target_device)
|
||||
|
||||
print(f"Starting generation with cfg_scale: {args.cfg_scale}")
|
||||
|
||||
# Generate audio
|
||||
|
@ -294,7 +339,6 @@ def main():
|
|||
max_new_tokens=None,
|
||||
cfg_scale=args.cfg_scale,
|
||||
tokenizer=processor.tokenizer,
|
||||
# generation_config={'do_sample': False, 'temperature': 0.95, 'top_p': 0.95, 'top_k': 0},
|
||||
generation_config={'do_sample': False},
|
||||
verbose=True,
|
||||
)
|
||||
|
@ -323,13 +367,13 @@ def main():
|
|||
print(f"Generated tokens: {generated_tokens}")
|
||||
print(f"Total tokens: {output_tokens}")
|
||||
|
||||
# Save output
|
||||
# Save output (processor handles device internally)
|
||||
txt_filename = os.path.splitext(os.path.basename(args.txt_path))[0]
|
||||
output_path = os.path.join(args.output_dir, f"{txt_filename}_generated.wav")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
processor.save_audio(
|
||||
outputs.speech_outputs[0], # First (and only) batch item
|
||||
outputs.speech_outputs[0], # First (and only) batch item
|
||||
output_path=output_path,
|
||||
)
|
||||
print(f"Saved output to {output_path}")
|
||||
|
|
Loading…
Reference in a new issue