diff --git a/demo/inference_from_file.py b/demo/inference_from_file.py index 5478568..2e4cbaf 100644 --- a/demo/inference_from_file.py +++ b/demo/inference_from_file.py @@ -264,6 +264,9 @@ def main(): elif args.device == "cuda": load_dtype = torch.bfloat16 attn_impl_primary = "flash_attention_2" + elif args.device == "xpu": # test if xpu can handle same config as cuda + load_dtype = torch.bfloat16 + attn_impl_primary = "flash_attention_2" else: # cpu load_dtype = torch.float32 attn_impl_primary = "sdpa"