Avoid duplicate import in IPEX auto importer (#11227)

* Add custom import to avoid ipex duplicate importing
* Add scope limitation
This commit is contained in:
Qiyuan Gong 2024-06-07 14:08:00 +08:00 committed by GitHub
parent 6f2684e5c9
commit 1aa9c9597a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -19,6 +19,43 @@ import logging
import builtins
import sys
from ipex_llm.utils.common import log4Error
import inspect
# Save the original __import__ function
original_import = builtins.__import__
ipex_duplicate_import_error = "intel_extension_for_pytorch has already been automatically " + \
"imported. Please avoid importing it again!"
def get_calling_package():
"""
Return calling package name, e.g., ipex_llm.transformers
"""
# Get the current stack frame
frame = inspect.currentframe()
# Get the caller's frame
caller_frame = frame.f_back.f_back
# Get the caller's module
module = inspect.getmodule(caller_frame)
if module:
# Return the module's package name
return module.__package__
return None
def custom_ipex_import(name, globals=None, locals=None, fromlist=(), level=0):
"""
Custom import function to avoid importing ipex again
"""
# check import calling pacage
calling_package = get_calling_package()
if calling_package is not None:
return original_import(name, globals, locals, fromlist, level)
# Only check ipex for main thread
if name == "ipex" or name == "intel_extension_for_pytorch":
log4Error.invalidInputError(False,
ipex_duplicate_import_error)
return original_import(name, globals, locals, fromlist, level)
class IPEXImporter:
@ -61,12 +98,12 @@ class IPEXImporter:
if self.is_xpu_version_installed():
# Check if user import ipex manually
if 'ipex' in sys.modules or 'intel_extension_for_pytorch' in sys.modules:
logging.error("ipex_llm will automatically import intel_extension_for_pytorch.")
log4Error.invalidInputError(False,
"Please import ipex_llm before importing \
intel_extension_for_pytorch!")
ipex_duplicate_import_error)
self.directly_import_ipex()
self.ipex_version = ipex.__version__
# Replace default importer
builtins.__import__ = custom_ipex_import
logging.info("intel_extension_for_pytorch auto imported")
def directly_import_ipex(self):
@ -95,6 +132,8 @@ class IPEXImporter:
# try to import Intel Extension for PyTorch and get version
self.directly_import_ipex()
self.ipex_version = ipex.__version__
# Replace default importer
builtins.__import__ = custom_ipex_import
return self.ipex_version