Avoid duplicate import in IPEX auto importer (#11227)
* Add custom import to avoid ipex duplicate importing * Add scope limitation
This commit is contained in:
parent
6f2684e5c9
commit
1aa9c9597a
1 changed files with 42 additions and 3 deletions
|
|
@ -19,6 +19,43 @@ import logging
|
||||||
import builtins
|
import builtins
|
||||||
import sys
|
import sys
|
||||||
from ipex_llm.utils.common import log4Error
|
from ipex_llm.utils.common import log4Error
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
# Save the original __import__ function
|
||||||
|
original_import = builtins.__import__
|
||||||
|
ipex_duplicate_import_error = "intel_extension_for_pytorch has already been automatically " + \
|
||||||
|
"imported. Please avoid importing it again!"
|
||||||
|
|
||||||
|
|
||||||
|
def get_calling_package():
|
||||||
|
"""
|
||||||
|
Return calling package name, e.g., ipex_llm.transformers
|
||||||
|
"""
|
||||||
|
# Get the current stack frame
|
||||||
|
frame = inspect.currentframe()
|
||||||
|
# Get the caller's frame
|
||||||
|
caller_frame = frame.f_back.f_back
|
||||||
|
# Get the caller's module
|
||||||
|
module = inspect.getmodule(caller_frame)
|
||||||
|
if module:
|
||||||
|
# Return the module's package name
|
||||||
|
return module.__package__
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def custom_ipex_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||||
|
"""
|
||||||
|
Custom import function to avoid importing ipex again
|
||||||
|
"""
|
||||||
|
# check import calling pacage
|
||||||
|
calling_package = get_calling_package()
|
||||||
|
if calling_package is not None:
|
||||||
|
return original_import(name, globals, locals, fromlist, level)
|
||||||
|
# Only check ipex for main thread
|
||||||
|
if name == "ipex" or name == "intel_extension_for_pytorch":
|
||||||
|
log4Error.invalidInputError(False,
|
||||||
|
ipex_duplicate_import_error)
|
||||||
|
return original_import(name, globals, locals, fromlist, level)
|
||||||
|
|
||||||
|
|
||||||
class IPEXImporter:
|
class IPEXImporter:
|
||||||
|
|
@ -61,12 +98,12 @@ class IPEXImporter:
|
||||||
if self.is_xpu_version_installed():
|
if self.is_xpu_version_installed():
|
||||||
# Check if user import ipex manually
|
# Check if user import ipex manually
|
||||||
if 'ipex' in sys.modules or 'intel_extension_for_pytorch' in sys.modules:
|
if 'ipex' in sys.modules or 'intel_extension_for_pytorch' in sys.modules:
|
||||||
logging.error("ipex_llm will automatically import intel_extension_for_pytorch.")
|
|
||||||
log4Error.invalidInputError(False,
|
log4Error.invalidInputError(False,
|
||||||
"Please import ipex_llm before importing \
|
ipex_duplicate_import_error)
|
||||||
intel_extension_for_pytorch!")
|
|
||||||
self.directly_import_ipex()
|
self.directly_import_ipex()
|
||||||
self.ipex_version = ipex.__version__
|
self.ipex_version = ipex.__version__
|
||||||
|
# Replace default importer
|
||||||
|
builtins.__import__ = custom_ipex_import
|
||||||
logging.info("intel_extension_for_pytorch auto imported")
|
logging.info("intel_extension_for_pytorch auto imported")
|
||||||
|
|
||||||
def directly_import_ipex(self):
|
def directly_import_ipex(self):
|
||||||
|
|
@ -95,6 +132,8 @@ class IPEXImporter:
|
||||||
# try to import Intel Extension for PyTorch and get version
|
# try to import Intel Extension for PyTorch and get version
|
||||||
self.directly_import_ipex()
|
self.directly_import_ipex()
|
||||||
self.ipex_version = ipex.__version__
|
self.ipex_version = ipex.__version__
|
||||||
|
# Replace default importer
|
||||||
|
builtins.__import__ = custom_ipex_import
|
||||||
return self.ipex_version
|
return self.ipex_version
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue