check device name in use_flash_attention (#11263)
parent 2623944604
commit 151fcf37bb
1 changed file with 2 additions and 0 deletions
@@ -301,6 +301,8 @@ def use_flash_attention(query, key, attention_mask=None):
         # ipex flash attention is only supported for xetla
         # may update this later
         return False
+    elif get_xpu_device_type(query) != "pvc":
+        return False
     if query.dtype not in [torch.float32, torch.float16]:
         # only use flash attention for fp32/fp16 input
         return False
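
For context, the sketch below illustrates the gating behavior this commit adds, assuming get_xpu_device_type(query) returns a lowercase device-name string such as "pvc" for the tensor's XPU. The helper can_use_flash_attention and the sample device name "arc" are hypothetical stand-ins, not the repository's code; the real use_flash_attention performs additional checks before these lines, as the hunk above shows for xetla support. Placing the device check ahead of the dtype check means unsupported hardware bails out early regardless of input precision.

import torch


def can_use_flash_attention(query: torch.Tensor, xpu_device_type: str) -> bool:
    # Hypothetical helper mirroring the device/dtype gate in the hunk above.
    # xpu_device_type stands in for get_xpu_device_type(query), which is
    # assumed to return a device-name string such as "pvc".
    if xpu_device_type != "pvc":
        # the check added by this commit: only PVC XPUs take the flash path
        return False
    if query.dtype not in [torch.float32, torch.float16]:
        # only use flash attention for fp32/fp16 input
        return False
    return True


# Usage: a non-"pvc" device name is now rejected before the dtype check.
q = torch.randn(1, 32, 128, 64, dtype=torch.float16)
print(can_use_flash_attention(q, "pvc"))  # True
print(can_use_flash_attention(q, "arc"))  # False ("arc" is an illustrative non-PVC name)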