From 546f455e8eb942df16edb562c83eae3f1e5081d1 Mon Sep 17 00:00:00 2001 From: Zhao Changmin Date: Tue, 29 Oct 2024 18:41:09 +0800 Subject: [PATCH] Patch sdpa check function in specific module attributes table (#12285) --- python/llm/src/ipex_llm/transformers/model.py | 4 ++-- python/llm/src/ipex_llm/transformers/patches.py | 10 +++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py index f81ee840..3e68d8ac 100644 --- a/python/llm/src/ipex_llm/transformers/model.py +++ b/python/llm/src/ipex_llm/transformers/model.py @@ -114,7 +114,7 @@ class _BaseAutoModelClass: @classmethod @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import) - @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True) + @patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True) def from_pretrained(cls, *args, **kwargs): @@ -542,7 +542,7 @@ class _BaseAutoModelClass: @classmethod @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import) - @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True) + @patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True) def load_low_bit(cls, pretrained_model_name_or_path, *model_args, diff --git a/python/llm/src/ipex_llm/transformers/patches.py b/python/llm/src/ipex_llm/transformers/patches.py index f115ffa5..743232c5 100644 --- a/python/llm/src/ipex_llm/transformers/patches.py +++ b/python/llm/src/ipex_llm/transformers/patches.py @@ -17,6 +17,7 @@ from typing import List from transformers.dynamic_module_utils import get_imports +from ipex_llm.utils.ipex_importer import IPEXImporter def patch_flash_attn_import(filename: str) -> List[str]: @@ -28,4 +29,11 @@ def patch_flash_attn_import(filename: str) -> List[str]: def patch_sdpa_available() -> bool: - return False + if IPEXImporter.is_xpu_version_installed(): + return False + else: + try: + from transformers.utils import is_torch_sdpa_available + return is_torch_sdpa_available() + except ImportError: + return False