feat: use same config as cuda for xpu
This commit is contained in:
parent
90e1cd1486
commit
aeb26374df
1 changed files with 3 additions and 0 deletions
|
@ -264,6 +264,9 @@ def main():
|
|||
elif args.device == "cuda":
|
||||
load_dtype = torch.bfloat16
|
||||
attn_impl_primary = "flash_attention_2"
|
||||
elif args.device == "xpu": # test if xpu can handle same config as cuda
|
||||
load_dtype = torch.bfloat16
|
||||
attn_impl_primary = "flash_attention_2"
|
||||
else: # cpu
|
||||
load_dtype = torch.float32
|
||||
attn_impl_primary = "sdpa"
|
||||
|
|
Loading…
Reference in a new issue