feat: load to XPU

This commit is contained in:
Ayo Ayco 2025-09-03 09:02:23 +02:00
parent aeb26374df
commit 1cb2a50ce5
2 changed files with 14 additions and 3 deletions

View file

@@ -13,6 +13,7 @@ from transformers.utils import logging
 logging.set_verbosity_info()
 logger = logging.get_logger(__name__)
+torch.no_grad()
 class VoiceMapper:
     """Maps speaker names to voice file paths"""
@@ -264,9 +265,9 @@ def main():
     elif args.device == "cuda":
         load_dtype = torch.bfloat16
         attn_impl_primary = "flash_attention_2"
-    elif args.device == "xpu": # test if xpu can handle same config as cuda
-        load_dtype = torch.bfloat16
-        attn_impl_primary = "flash_attention_2"
+    elif args.device == "xpu":
+        load_dtype = torch.float32
+        attn_impl_primary = "sdpa"
     else: # cpu
         load_dtype = torch.float32
         attn_impl_primary = "sdpa"
@@ -288,6 +289,15 @@ def main():
             device_map="cuda",
             attn_implementation=attn_impl_primary,
         )
+    elif args.device == "xpu":
+        torch.xpu.empty_cache()
+        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+            args.model_path,
+            torch_dtype=load_dtype,
+            device_map=None,
+            attn_implementation=attn_impl_primary,
+        )
+        model.to('xpu')
     else: # cpu
         model = VibeVoiceForConditionalGenerationInference.from_pretrained(
             args.model_path,

example.command — new file (1 line)
View file

@@ -0,0 +1 @@
+python demo/inference_from_file.py --model_path microsoft/VibeVoice-Large --txt_path demo/text_examples/2p_music.txt --speaker_names Alice Frank