Remove sdpa available patch (#12734)

parent c9b6c94a59
commit dcca522618

2 changed files with 1 addition and 14 deletions
@@ -51,7 +51,7 @@ from ipex_llm.transformers.gguf.api import load_gguf_model
 
 from .utils import logger, load_state_dict
 from .utils import extract_local_archive_file, get_local_shard_files, load_imatrix_data
-from .patches import patch_flash_attn_import, patch_sdpa_available
+from .patches import patch_flash_attn_import
 
 patched_training_mode = None
 
@@ -108,7 +108,6 @@ class _BaseAutoModelClass:
 
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
-    @patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
     def from_pretrained(cls,
                         *args,
                         **kwargs):
@@ -531,7 +530,6 @@ class _BaseAutoModelClass:
 
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
-    @patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
     def load_low_bit(cls,
                      pretrained_model_name_or_path,
                      *model_args,
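For context on the decorator being deleted: unittest.mock.patch used with create=True temporarily injects the named attribute into the target module even when no such attribute exists there, and restores the module when the call returns. A minimal, self-contained sketch of that mechanism, patching a hypothetical name onto the standard-library math module rather than transformers:

import math
from unittest.mock import patch

# create=True allows patching a name that does not already exist on the
# target module; without it, patch() raises AttributeError here.
@patch("math.is_torch_sdpa_available", lambda: False, create=True)
def demo():
    # While the patch is active, the injected attribute is callable.
    return math.is_torch_sdpa_available()

print(demo())                                    # False
print(hasattr(math, "is_torch_sdpa_available"))  # False: removed on exit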
@@ -26,14 +26,3 @@ def patch_flash_attn_import(filename: str) -> List[str]:
     if "flash_attn" in imports:
         imports.remove("flash_attn")
     return imports
-
-
-def patch_sdpa_available() -> bool:
-    if IPEXImporter.is_xpu_version_installed():
-        return False
-    else:
-        try:
-            from transformers.utils import is_torch_sdpa_available
-            return is_torch_sdpa_available()
-        except ImportError:
-            return False
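The removed helper force-reported SDPA as unavailable on XPU builds of IPEX and otherwise deferred to transformers' own check. After this commit nothing overrides is_torch_sdpa_available, so model loading follows whatever the installed transformers reports. A hedged standalone equivalent of the fallback path, should a caller still want it (the IPEXImporter dependency is dropped; the function name here is illustrative):

def sdpa_available() -> bool:
    # Defer to transformers when the helper exists (present in newer
    # transformers releases); report False on versions that lack it.
    try:
        from transformers.utils import is_torch_sdpa_available
        return is_torch_sdpa_available()
    except ImportError:
        return False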