Patch sdpa check function in specific module attributes table (#12285)

This commit is contained in:
Zhao Changmin 2024-10-29 18:41:09 +08:00 committed by GitHub
parent 3700e81977
commit 546f455e8e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 11 additions and 3 deletions

View file

@ -114,7 +114,7 @@ class _BaseAutoModelClass:
@classmethod
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
@patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
@patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
def from_pretrained(cls,
*args,
**kwargs):
@ -542,7 +542,7 @@ class _BaseAutoModelClass:
@classmethod
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
@patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
@patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
def load_low_bit(cls,
pretrained_model_name_or_path,
*model_args,

View file

@ -17,6 +17,7 @@
from typing import List
from transformers.dynamic_module_utils import get_imports
from ipex_llm.utils.ipex_importer import IPEXImporter
def patch_flash_attn_import(filename: str) -> List[str]:
@ -28,4 +29,11 @@ def patch_flash_attn_import(filename: str) -> List[str]:
def patch_sdpa_available() -> bool:
if IPEXImporter.is_xpu_version_installed():
return False
else:
try:
from transformers.utils import is_torch_sdpa_available
return is_torch_sdpa_available()
except ImportError:
return False