Remove sdpa available patch (#12734)
This commit is contained in:
parent
c9b6c94a59
commit
dcca522618
2 changed files with 1 additions and 14 deletions
|
|
@ -51,7 +51,7 @@ from ipex_llm.transformers.gguf.api import load_gguf_model
|
||||||
|
|
||||||
from .utils import logger, load_state_dict
|
from .utils import logger, load_state_dict
|
||||||
from .utils import extract_local_archive_file, get_local_shard_files, load_imatrix_data
|
from .utils import extract_local_archive_file, get_local_shard_files, load_imatrix_data
|
||||||
from .patches import patch_flash_attn_import, patch_sdpa_available
|
from .patches import patch_flash_attn_import
|
||||||
|
|
||||||
patched_training_mode = None
|
patched_training_mode = None
|
||||||
|
|
||||||
|
|
@ -108,7 +108,6 @@ class _BaseAutoModelClass:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
||||||
@patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
|
|
||||||
def from_pretrained(cls,
|
def from_pretrained(cls,
|
||||||
*args,
|
*args,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
@ -531,7 +530,6 @@ class _BaseAutoModelClass:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
|
||||||
@patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
|
|
||||||
def load_low_bit(cls,
|
def load_low_bit(cls,
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
*model_args,
|
*model_args,
|
||||||
|
|
|
||||||
|
|
@ -26,14 +26,3 @@ def patch_flash_attn_import(filename: str) -> List[str]:
|
||||||
if "flash_attn" in imports:
|
if "flash_attn" in imports:
|
||||||
imports.remove("flash_attn")
|
imports.remove("flash_attn")
|
||||||
return imports
|
return imports
|
||||||
|
|
||||||
|
|
||||||
def patch_sdpa_available() -> bool:
|
|
||||||
if IPEXImporter.is_xpu_version_installed():
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
from transformers.utils import is_torch_sdpa_available
|
|
||||||
return is_torch_sdpa_available()
|
|
||||||
except ImportError:
|
|
||||||
return False
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue