feat: use same config as cuda for xpu

commit aeb26374df
parent 90e1cd1486
Author: Ayo Ayco
Date:   2025-09-02 23:26:22 +02:00

@@ -264,6 +264,9 @@ def main():
     elif args.device == "cuda":
         load_dtype = torch.bfloat16
         attn_impl_primary = "flash_attention_2"
+    elif args.device == "xpu":  # test if xpu can handle same config as cuda
+        load_dtype = torch.bfloat16
+        attn_impl_primary = "flash_attention_2"
     else:  # cpu
         load_dtype = torch.float32
         attn_impl_primary = "sdpa"
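For context, below is a minimal sketch (not the repo's actual code) of how values like load_dtype and attn_impl_primary are typically consumed when loading a model with Hugging Face Transformers. The checkpoint name ("gpt2"), the pick_config helper, and the try/except fallback are illustrative assumptions; the commit only shows the branch selecting the same bfloat16 + flash_attention_2 config for xpu as for cuda.

import torch
from transformers import AutoModelForCausalLM

def pick_config(device: str):
    """Map a device string to a load dtype and preferred attention backend."""
    if device in ("cuda", "xpu"):  # per this commit, xpu mirrors the cuda config
        return torch.bfloat16, "flash_attention_2"
    return torch.float32, "sdpa"  # cpu fallback

# torch.xpu is available in recent PyTorch builds with Intel XPU support.
device = "xpu" if torch.xpu.is_available() else "cpu"
load_dtype, attn_impl_primary = pick_config(device)

try:
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",  # placeholder checkpoint; the commit's model is not shown
        torch_dtype=load_dtype,
        attn_implementation=attn_impl_primary,
    ).to(device)
except (ValueError, ImportError):
    # flash_attention_2 requires the flash-attn package and may not be
    # supported on xpu; fall back to PyTorch's built-in SDPA backend.
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        torch_dtype=load_dtype,
        attn_implementation="sdpa",
    ).to(device)

A fallback like this is one way to make the "test if xpu can handle same config as cuda" experiment safe: if flash_attention_2 is rejected on the device, loading still succeeds with sdpa rather than crashing.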