support CPU
This commit is contained in:
parent
086cd6aa6c
commit
3074f898ca
1 changed files with 19 additions and 12 deletions
|
@ -246,23 +246,30 @@ def main():
|
|||
processor = VibeVoiceProcessor.from_pretrained(args.model_path)
|
||||
|
||||
# Load model
|
||||
torch_dtype, device_map, attn_implementation = torch.bfloat16, 'cuda', 'flash_attention_2'
|
||||
if args.device == 'cpu':
|
||||
torch_dtype, device_map, attn_implementation = torch.float32, 'cpu', 'sdpa'
|
||||
print(f"Using device: {args.device}, torch_dtype: {torch_dtype}, attn_implementation: {attn_implementation}")
|
||||
try:
|
||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map='cuda',
|
||||
attn_implementation='flash_attention_2' # flash_attention_2 is recommended
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
attn_implementation=attn_implementation
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] : {type(e).__name__}: {e}")
|
||||
print(traceback.format_exc())
|
||||
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
|
||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map='cuda',
|
||||
attn_implementation='sdpa'
|
||||
)
|
||||
if attn_implementation == 'flash_attention_2':
|
||||
print(f"[ERROR] : {type(e).__name__}: {e}")
|
||||
print(traceback.format_exc())
|
||||
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
|
||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
args.model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
attn_implementation='sdpa'
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
|
||||
model.eval()
|
||||
model.set_ddpm_inference_steps(num_steps=10)
|
||||
|
|
Loading…
Reference in a new issue