diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 3e2878b9..449d331a 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -301,6 +301,8 @@ def use_flash_attention(query, key, attention_mask=None):
         # ipex flash attention is only supported for xetla
         # may update this later
         return False
+    elif get_xpu_device_type(query) != "pvc":
+        return False
     if query.dtype not in [torch.float32, torch.float16]:
         # only use flash attention for fp32/fp16 input
         return False
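
For readers skimming the diff: the new `elif` tightens the existing guard so that flash attention is only taken on PVC (Intel Data Center GPU Max) devices; all other XPU device families now fall through to `return False`. The sketch below is a minimal, self-contained illustration of the resulting guard logic, not the verbatim ipex_llm implementation: `flash_attention_allowed` is a hypothetical wrapper name, and the stand-in `get_xpu_device_type` with its device-name matching is an assumption (ipex_llm ships its own helper of that name).

# Minimal sketch of the guard logic after this patch, assuming a PyTorch
# build with intel_extension_for_pytorch so that torch.xpu is available.
import torch


def get_xpu_device_type(x: torch.Tensor) -> str:
    # Simplified stand-in for ipex_llm's helper: map an XPU tensor to a
    # coarse device family, e.g. "pvc" for Intel Data Center GPU Max
    # (Ponte Vecchio). The exact name matching here is an assumption.
    name = torch.xpu.get_device_name(x.device.index)
    if name.startswith("Intel(R) Data Center GPU Max"):
        return "pvc"
    return "others"


def flash_attention_allowed(query: torch.Tensor) -> bool:
    # New in this patch: even when the XeTLA-based kernels are available,
    # flash attention is only enabled on PVC devices.
    if get_xpu_device_type(query) != "pvc":
        return False
    if query.dtype not in [torch.float32, torch.float16]:
        # only use flash attention for fp32/fp16 input
        # (unchanged context from the diff)
        return False
    return True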