feat: add Apple MPS device support (dtype + attention handling) to demos (#67)

* Enhance model loading with device support and error handling: update device handling during model loading, add MPS support, and improve error handling and fallback behavior for attention implementations.

* Improve device handling and model loading logic: accept MPS in the device argument, validate MPS availability, and choose the model-loading path based on the selected device type.

* Fall back to SDPA only when the primary implementation is flash_attention_2, and restore some comments.

---------

Co-authored-by: YaoyaoChang <cyy574006791@qq.com>
parent 3074f898ca
commit 6dd49cb720

2 changed files with 130 additions and 44 deletions
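Both demos now follow the same device policy: normalize the 'mpx' typo to 'mps', fall back to CPU when MPS is requested but unavailable, then pick a dtype and attention implementation per device (float32 + SDPA on MPS and CPU, bfloat16 + flash_attention_2 on CUDA). A minimal sketch of that policy, with resolve_device_settings as a hypothetical helper name (the demos inline this logic rather than factoring it out):

import torch

def resolve_device_settings(device: str):
    # Hypothetical helper; mirrors the inline logic added by this commit.
    if device.lower() == "mpx":  # normalize the 'mpx' typo
        device = "mps"
    if device == "mps" and not torch.backends.mps.is_available():
        device = "cpu"  # MPS requested but unavailable
    if device == "mps":
        return device, torch.float32, "sdpa"
    if device == "cuda":
        return device, torch.bfloat16, "flash_attention_2"
    return "cpu", torch.float32, "sdpa"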
@@ -47,30 +47,67 @@ class VibeVoiceDemo:
     def load_model(self):
         """Load the VibeVoice model and processor."""
         print(f"Loading processor & model from {self.model_path}")
+
+        # Normalize potential 'mpx'
+        if self.device.lower() == "mpx":
+            print("Note: device 'mpx' detected, treating it as 'mps'.")
+            self.device = "mps"
+        if self.device == "mps" and not torch.backends.mps.is_available():
+            print("Warning: MPS not available. Falling back to CPU.")
+            self.device = "cpu"
+        print(f"Using device: {self.device}")
+
         # Load processor
-        self.processor = VibeVoiceProcessor.from_pretrained(
-            self.model_path,
-        )
+        self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
+
+        # Decide dtype & attention
+        if self.device == "mps":
+            load_dtype = torch.float32
+            attn_impl_primary = "sdpa"
+        elif self.device == "cuda":
+            load_dtype = torch.bfloat16
+            attn_impl_primary = "flash_attention_2"
+        else:
+            load_dtype = torch.float32
+            attn_impl_primary = "sdpa"
+        print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
+
         # Load model
         try:
-            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-                self.model_path,
-                torch_dtype=torch.bfloat16,
-                device_map='cuda',
-                attn_implementation='flash_attention_2' # flash_attention_2 is recommended
-            )
+            if self.device == "mps":
+                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                    self.model_path,
+                    torch_dtype=load_dtype,
+                    attn_implementation=attn_impl_primary,
+                    device_map=None,
+                )
+                self.model.to("mps")
+            elif self.device == "cuda":
+                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                    self.model_path,
+                    torch_dtype=load_dtype,
+                    device_map="cuda",
+                    attn_implementation=attn_impl_primary,
+                )
+            else:
+                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                    self.model_path,
+                    torch_dtype=load_dtype,
+                    device_map="cpu",
+                    attn_implementation=attn_impl_primary,
+                )
         except Exception as e:
-            print(f"[ERROR] : {type(e).__name__}: {e}")
-            print(traceback.format_exc())
-            print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
-            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-                self.model_path,
-                torch_dtype=torch.bfloat16,
-                device_map='cuda',
-                attn_implementation='sdpa'
-            )
+            if attn_impl_primary == 'flash_attention_2':
+                print(f"[ERROR] : {type(e).__name__}: {e}")
+                print(traceback.format_exc())
+                fallback_attn = "sdpa"
+                print(f"Falling back to attention implementation: {fallback_attn}")
+                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                    self.model_path,
+                    torch_dtype=load_dtype,
+                    device_map=(self.device if self.device in ("cuda", "cpu") else None),
+                    attn_implementation=fallback_attn,
+                )
+                if self.device == "mps":
+                    self.model.to("mps")
+            else:
+                raise e
         self.model.eval()

         # Use SDE solver by default
@@ -238,6 +275,11 @@ class VibeVoiceDemo:
             return_tensors="pt",
             return_attention_mask=True,
         )
+        # Move tensors to device
+        target_device = self.device if self.device in ("cuda", "mps") else "cpu"
+        for k, v in inputs.items():
+            if torch.is_tensor(v):
+                inputs[k] = v.to(target_device)

         # Create audio streamer
         audio_streamer = AudioStreamer(
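The per-key loop in the hunk above guards with torch.is_tensor because processor outputs can mix tensors with non-tensor values. Assuming inputs behaves like a plain dict, an equivalent one-liner would be:

# Hypothetical equivalent of the loop in the hunk above
inputs = {k: (v.to(target_device) if torch.is_tensor(v) else v) for k, v in inputs.items()}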
@@ -1133,8 +1175,8 @@ def parse_args():
     parser.add_argument(
         "--device",
         type=str,
-        default="cuda" if torch.cuda.is_available() else "cpu",
-        help="Device for inference",
+        default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
+        help="Device for inference: cuda | mps | cpu",
     )
     parser.add_argument(
         "--inference_steps",
(second changed file)

@@ -167,8 +167,8 @@ def parse_args():
     parser.add_argument(
         "--device",
         type=str,
-        default="cuda" if torch.cuda.is_available() else "cpu",
-        help="Device for tensor tests",
+        default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
+        help="Device for inference: cuda | mps | cpu",
     )
     parser.add_argument(
         "--cfg_scale",
@@ -182,6 +182,18 @@ def main():
 def main():
     args = parse_args()

+    # Normalize potential 'mpx' typo to 'mps'
+    if args.device.lower() == "mpx":
+        print("Note: device 'mpx' detected, treating it as 'mps'.")
+        args.device = "mps"
+
+    # Validate mps availability if requested
+    if args.device == "mps" and not torch.backends.mps.is_available():
+        print("Warning: MPS not available. Falling back to CPU.")
+        args.device = "cpu"
+
+    print(f"Using device: {args.device}")
+
     # Initialize voice mapper
     voice_mapper = VoiceMapper()
@@ -241,36 +253,62 @@ def main():
     full_script = '\n'.join(scripts)
     full_script = full_script.replace("’", "'")

-    # Load processor
     print(f"Loading processor & model from {args.model_path}")
     processor = VibeVoiceProcessor.from_pretrained(args.model_path)

-    # Load model
-    torch_dtype, device_map, attn_implementation = torch.bfloat16, 'cuda', 'flash_attention_2'
-    if args.device == 'cpu':
-        torch_dtype, device_map, attn_implementation = torch.float32, 'cpu', 'sdpa'
-    print(f"Using device: {args.device}, torch_dtype: {torch_dtype}, attn_implementation: {attn_implementation}")
+    # Decide dtype & attention implementation
+    if args.device == "mps":
+        load_dtype = torch.float32  # MPS requires float32
+        attn_impl_primary = "sdpa"  # flash_attention_2 not supported on MPS
+    elif args.device == "cuda":
+        load_dtype = torch.bfloat16
+        attn_impl_primary = "flash_attention_2"
+    else:  # cpu
+        load_dtype = torch.float32
+        attn_impl_primary = "sdpa"
+    print(f"Using device: {args.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
+
+    # Load model with device-specific logic
     try:
-        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-            args.model_path,
-            torch_dtype=torch_dtype,
-            device_map=device_map,
-            attn_implementation=attn_implementation
-        )
+        if args.device == "mps":
+            model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                args.model_path,
+                torch_dtype=load_dtype,
+                attn_implementation=attn_impl_primary,
+                device_map=None,  # load then move
+            )
+            model.to("mps")
+        elif args.device == "cuda":
+            model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                args.model_path,
+                torch_dtype=load_dtype,
+                device_map="cuda",
+                attn_implementation=attn_impl_primary,
+            )
+        else:  # cpu
+            model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                args.model_path,
+                torch_dtype=load_dtype,
+                device_map="cpu",
+                attn_implementation=attn_impl_primary,
+            )
     except Exception as e:
-        if attn_implementation == 'flash_attention_2':
+        if attn_impl_primary == 'flash_attention_2':
             print(f"[ERROR] : {type(e).__name__}: {e}")
             print(traceback.format_exc())
             print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
             model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                 args.model_path,
-                torch_dtype=torch_dtype,
-                device_map=device_map,
+                torch_dtype=load_dtype,
+                device_map=(args.device if args.device in ("cuda", "cpu") else None),
                 attn_implementation='sdpa'
             )
+            if args.device == "mps":
+                model.to("mps")
         else:
             raise e

     model.eval()
     model.set_ddpm_inference_steps(num_steps=10)
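Both files repeat the same load-then-fallback pattern: try the primary attention implementation, retry with SDPA only when the primary was flash_attention_2, and re-raise otherwise; MPS loads with device_map=None and is moved afterwards (the diff's own comment is "load then move"). A condensed sketch of the pattern, where load_with_fallback and its loader parameter are illustrative names standing in for VibeVoiceForConditionalGenerationInference.from_pretrained:

import traceback
import torch

def load_with_fallback(loader, model_path, device, dtype, attn_impl):
    # Hypothetical wrapper; the demos inline this pattern instead.
    device_map = None if device == "mps" else device  # MPS: load first, move later
    try:
        model = loader(model_path, torch_dtype=dtype,
                       attn_implementation=attn_impl, device_map=device_map)
    except Exception:
        if attn_impl != "flash_attention_2":
            raise  # only the flash_attention_2 path has a fallback
        traceback.print_exc()
        model = loader(model_path, torch_dtype=dtype,
                       attn_implementation="sdpa", device_map=device_map)
    if device == "mps":
        model.to("mps")  # move explicitly, since device_map did not place it
    return model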
@@ -279,12 +317,19 @@ def main():

     # Prepare inputs for the model
     inputs = processor(
         text=[full_script],  # Wrap in list for batch processing
         voice_samples=[voice_samples],  # Wrap in list for batch processing
         padding=True,
         return_tensors="pt",
         return_attention_mask=True,
     )
+
+    # Move tensors to target device
+    target_device = args.device if args.device != "cpu" else "cpu"
+    for k, v in inputs.items():
+        if torch.is_tensor(v):
+            inputs[k] = v.to(target_device)

     print(f"Starting generation with cfg_scale: {args.cfg_scale}")

     # Generate audio
@@ -294,7 +339,6 @@ def main():
         max_new_tokens=None,
         cfg_scale=args.cfg_scale,
         tokenizer=processor.tokenizer,
-        # generation_config={'do_sample': False, 'temperature': 0.95, 'top_p': 0.95, 'top_k': 0},
         generation_config={'do_sample': False},
         verbose=True,
     )
@@ -323,13 +367,13 @@ def main():
     print(f"Generated tokens: {generated_tokens}")
     print(f"Total tokens: {output_tokens}")

-    # Save output
+    # Save output (processor handles device internally)
     txt_filename = os.path.splitext(os.path.basename(args.txt_path))[0]
     output_path = os.path.join(args.output_dir, f"{txt_filename}_generated.wav")
     os.makedirs(args.output_dir, exist_ok=True)

     processor.save_audio(
         outputs.speech_outputs[0],  # First (and only) batch item
         output_path=output_path,
     )
     print(f"Saved output to {output_path}")
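With the new --device default and validation, the file-based demo should run on Apple silicon with an invocation along these lines (the script name and paths are placeholders; only the flags appear in this diff):

python inference_from_file.py --model_path <model_dir> --txt_path <script.txt> --device mps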