feat: load to XPU
This commit is contained in:
parent
aeb26374df
commit
1cb2a50ce5
2 changed files with 14 additions and 3 deletions
|
@ -13,6 +13,7 @@ from transformers.utils import logging
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
torch.no_grad()
|
||||||
|
|
||||||
class VoiceMapper:
|
class VoiceMapper:
|
||||||
"""Maps speaker names to voice file paths"""
|
"""Maps speaker names to voice file paths"""
|
||||||
|
@ -264,9 +265,9 @@ def main():
|
||||||
elif args.device == "cuda":
|
elif args.device == "cuda":
|
||||||
load_dtype = torch.bfloat16
|
load_dtype = torch.bfloat16
|
||||||
attn_impl_primary = "flash_attention_2"
|
attn_impl_primary = "flash_attention_2"
|
||||||
elif args.device == "xpu": # test if xpu can handle same config as cuda
|
elif args.device == "xpu":
|
||||||
load_dtype = torch.bfloat16
|
load_dtype = torch.float32
|
||||||
attn_impl_primary = "flash_attention_2"
|
attn_impl_primary = "sdpa"
|
||||||
else: # cpu
|
else: # cpu
|
||||||
load_dtype = torch.float32
|
load_dtype = torch.float32
|
||||||
attn_impl_primary = "sdpa"
|
attn_impl_primary = "sdpa"
|
||||||
|
@ -288,6 +289,15 @@ def main():
|
||||||
device_map="cuda",
|
device_map="cuda",
|
||||||
attn_implementation=attn_impl_primary,
|
attn_implementation=attn_impl_primary,
|
||||||
)
|
)
|
||||||
|
elif args.device == "xpu":
|
||||||
|
torch.xpu.empty_cache()
|
||||||
|
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||||
|
args.model_path,
|
||||||
|
torch_dtype=load_dtype,
|
||||||
|
device_map=None,
|
||||||
|
attn_implementation=attn_impl_primary,
|
||||||
|
)
|
||||||
|
model.to('xpu')
|
||||||
else: # cpu
|
else: # cpu
|
||||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||||
args.model_path,
|
args.model_path,
|
||||||
|
|
1
example.command
Normal file
1
example.command
Normal file
|
@ -0,0 +1 @@
|
||||||
|
python demo/inference_from_file.py --model_path microsoft/VibeVoice-Large --txt_path demo/text_examples/2p_music.txt --speaker_names Alice Frank
|
Loading…
Reference in a new issue