check devie name in use_flash_attention (#11263)
This commit is contained in:
parent
2623944604
commit
151fcf37bb
1 changed files with 2 additions and 0 deletions
|
|
@ -301,6 +301,8 @@ def use_flash_attention(query, key, attention_mask=None):
|
|||
# ipex flash attention is only supported for xetla
|
||||
# may update this later
|
||||
return False
|
||||
elif get_xpu_device_type(query) != "pvc":
|
||||
return False
|
||||
if query.dtype not in [torch.float32, torch.float16]:
|
||||
# only use flash attention for fp32/fp16 input
|
||||
return False
|
||||
|
|
|
|||
Loading…
Reference in a new issue