diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index 5459056b..788a2edb 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -51,7 +51,7 @@
 from ipex_llm.transformers.gguf.api import load_gguf_model
 from .utils import logger, load_state_dict
 from .utils import extract_local_archive_file, get_local_shard_files, load_imatrix_data
-from .patches import patch_flash_attn_import, patch_sdpa_available
+from .patches import patch_flash_attn_import
 
 
 patched_training_mode = None
@@ -108,7 +108,6 @@ class _BaseAutoModelClass:
 
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
-    @patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
     def from_pretrained(cls,
                         *args,
                         **kwargs):
@@ -531,7 +530,6 @@ class _BaseAutoModelClass:
 
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
-    @patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
     def load_low_bit(cls,
                      pretrained_model_name_or_path,
                      *model_args,
diff --git a/python/llm/src/ipex_llm/transformers/patches.py b/python/llm/src/ipex_llm/transformers/patches.py
index 743232c5..9d00b15a 100644
--- a/python/llm/src/ipex_llm/transformers/patches.py
+++ b/python/llm/src/ipex_llm/transformers/patches.py
@@ -26,14 +26,3 @@ def patch_flash_attn_import(filename: str) -> List[str]:
     if "flash_attn" in imports:
         imports.remove("flash_attn")
     return imports
-
-
-def patch_sdpa_available() -> bool:
-    if IPEXImporter.is_xpu_version_installed():
-        return False
-    else:
-        try:
-            from transformers.utils import is_torch_sdpa_available
-            return is_torch_sdpa_available()
-        except ImportError:
-            return False