From 151fcf37bb3737bdc2a1786b418691b630b53788 Mon Sep 17 00:00:00 2001
From: Xin Qiu
Date: Fri, 7 Jun 2024 15:07:47 +0800
Subject: [PATCH] check device name in use_flash_attention (#11263)

---
 python/llm/src/ipex_llm/transformers/models/utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 3e2878b9..449d331a 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -301,6 +301,8 @@ def use_flash_attention(query, key, attention_mask=None):
         # ipex flash attention is only supported for xetla
         # may update this later
         return False
+    elif get_xpu_device_type(query) != "pvc":
+        return False
     if query.dtype not in [torch.float32, torch.float16]:
         # only use flash attention for fp32/fp16 input
         return False
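
Note: the added check makes use_flash_attention() return False on any XPU that
is not a PVC (Ponte Vecchio, i.e. Intel Data Center GPU Max) device, where the
IPEX flash-attention path is expected to work. Below is a minimal sketch of
that gate. The get_xpu_device_type() stub here is a hypothetical stand-in for
the helper of the same name in ipex_llm/transformers/models/utils.py; its
"max"-name heuristic and the simplified use_flash_attention() body are
assumptions for illustration, not the library's actual implementation.

    import torch


    def get_xpu_device_type(tensor: torch.Tensor) -> str:
        # Hypothetical stand-in for ipex_llm's helper of the same name:
        # classify the device behind `tensor` into a family name.
        if tensor.device.type != "xpu":
            return "cpu"
        name = torch.xpu.get_device_name(tensor.device.index).lower()
        # Assumption: PVC devices advertise themselves as "... Max ..."
        return "pvc" if "max" in name else "others"


    def use_flash_attention(query: torch.Tensor, key: torch.Tensor,
                            attention_mask=None) -> bool:
        # Sketch of the gate after this patch: bail out early on any
        # non-PVC device; the other checks in utils.py are elided here.
        if get_xpu_device_type(query) != "pvc":
            return False
        if query.dtype not in [torch.float32, torch.float16]:
            # only use flash attention for fp32/fp16 input
            return False
        return True


    # Usage: a CPU tensor is classified as "cpu", so the gate rejects it,
    # just as it would reject a non-PVC XPU after this patch.
    q = torch.randn(1, 32, 128, 64)
    k = torch.randn(1, 32, 128, 64)
    print(use_flash_attention(q, k))  # False: device is not "pvc"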