Patch sdpa check function in specific module attributes table (#12285)
This commit is contained in:
parent
3700e81977
commit
546f455e8e
2 changed files with 11 additions and 3 deletions
|
|
@ -114,7 +114,7 @@ class _BaseAutoModelClass:
|
|||
|
||||
@classmethod
|
||||
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
||||
@patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
|
||||
@patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
|
||||
def from_pretrained(cls,
|
||||
*args,
|
||||
**kwargs):
|
||||
|
|
@ -542,7 +542,7 @@ class _BaseAutoModelClass:
|
|||
|
||||
@classmethod
|
||||
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
||||
@patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
|
||||
@patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
|
||||
def load_low_bit(cls,
|
||||
pretrained_model_name_or_path,
|
||||
*model_args,
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
from typing import List
|
||||
from transformers.dynamic_module_utils import get_imports
|
||||
from ipex_llm.utils.ipex_importer import IPEXImporter
|
||||
|
||||
|
||||
def patch_flash_attn_import(filename: str) -> List[str]:
|
||||
|
|
@ -28,4 +29,11 @@ def patch_flash_attn_import(filename: str) -> List[str]:
|
|||
|
||||
|
||||
def patch_sdpa_available() -> bool:
|
||||
if IPEXImporter.is_xpu_version_installed():
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
from transformers.utils import is_torch_sdpa_available
|
||||
return is_torch_sdpa_available()
|
||||
except ImportError:
|
||||
return False
|
||||
|
|
|
|||
Loading…
Reference in a new issue