feat: use same config as cuda for xpu
parent 90e1cd1486
commit aeb26374df
1 changed file with 3 additions and 0 deletions
@@ -264,6 +264,9 @@ def main():
     elif args.device == "cuda":
         load_dtype = torch.bfloat16
         attn_impl_primary = "flash_attention_2"
+    elif args.device == "xpu": # test if xpu can handle same config as cuda
+        load_dtype = torch.bfloat16
+        attn_impl_primary = "flash_attention_2"
     else: # cpu
         load_dtype = torch.float32
         attn_impl_primary = "sdpa"
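
The new xpu branch simply mirrors the cuda settings (bfloat16 weights plus flash_attention_2). Whether flash_attention_2 actually runs on XPU depends on the installed attention backend, and the name attn_impl_primary suggests the script treats it as a first choice with a fallback. A minimal sketch of that pattern, assuming a Hugging Face transformers-style loader; the function name, model argument, and fallback logic below are illustrative assumptions, not code from this commit:

import torch
from transformers import AutoModelForCausalLM

def load_model(model_name: str, load_dtype: torch.dtype, attn_impl_primary: str):
    # Hypothetical fallback wiring: try the primary attention implementation
    # first (flash_attention_2 on cuda and, after this commit, xpu).
    try:
        return AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=load_dtype,
            attn_implementation=attn_impl_primary,
        )
    except (ImportError, ValueError):
        # Fall back to PyTorch's built-in scaled-dot-product attention
        # if the primary implementation is unavailable on this device.
        return AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=load_dtype,
            attn_implementation="sdpa",
        )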