check devie name in use_flash_attention (#11263)

This commit is contained in:
Xin Qiu 2024-06-07 15:07:47 +08:00 committed by GitHub
parent 2623944604
commit 151fcf37bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -301,6 +301,8 @@ def use_flash_attention(query, key, attention_mask=None):
# ipex flash attention is only supported for xetla
# may update this later
return False
elif get_xpu_device_type(query) != "pvc":
return False
if query.dtype not in [torch.float32, torch.float16]:
# only use flash attention for fp32/fp16 input
return False