From 1cb2a50ce5954d5871e2556f6a97a2be81cdcf9c Mon Sep 17 00:00:00 2001
From: Ayo
Date: Wed, 3 Sep 2025 09:02:23 +0200
Subject: [PATCH] feat: load to XPU

---
 demo/inference_from_file.py | 16 +++++++++++++---
 example.command             |  1 +
 2 files changed, 14 insertions(+), 3 deletions(-)
 create mode 100644 example.command

diff --git a/demo/inference_from_file.py b/demo/inference_from_file.py
index 2e4cbaf..e6e5612 100644
--- a/demo/inference_from_file.py
+++ b/demo/inference_from_file.py
@@ -13,6 +13,7 @@ from transformers.utils import logging
 
 logging.set_verbosity_info()
 logger = logging.get_logger(__name__)
+torch.no_grad()
 
 class VoiceMapper:
     """Maps speaker names to voice file paths"""
@@ -264,9 +265,9 @@ def main():
     elif args.device == "cuda":
         load_dtype = torch.bfloat16
         attn_impl_primary = "flash_attention_2"
-    elif args.device == "xpu": # test if xpu can handle same config as cuda
-        load_dtype = torch.bfloat16
-        attn_impl_primary = "flash_attention_2"
+    elif args.device == "xpu":
+        load_dtype = torch.float32
+        attn_impl_primary = "sdpa"
     else: # cpu
         load_dtype = torch.float32
         attn_impl_primary = "sdpa"
@@ -288,6 +289,15 @@ def main():
             device_map="cuda",
             attn_implementation=attn_impl_primary,
         )
+    elif args.device == "xpu":
+        torch.xpu.empty_cache()
+        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
+            args.model_path,
+            torch_dtype=load_dtype,
+            device_map=None,
+            attn_implementation=attn_impl_primary,
+        )
+        model.to('xpu')
     else: # cpu
         model = VibeVoiceForConditionalGenerationInference.from_pretrained(
             args.model_path,
diff --git a/example.command b/example.command
new file mode 100644
index 0000000..805bcbb
--- /dev/null
+++ b/example.command
@@ -0,0 +1 @@
+python demo/inference_from_file.py --model_path microsoft/VibeVoice-Large --txt_path demo/text_examples/2p_music.txt --speaker_names Alice Frank