From 1aa9c9597afd166b0427f28297a33a5e9c73c01a Mon Sep 17 00:00:00 2001 From: Qiyuan Gong Date: Fri, 7 Jun 2024 14:08:00 +0800 Subject: [PATCH] Avoid duplicate import in IPEX auto importer (#11227) * Add custom import to avoid ipex duplicate importing * Add scope limitation --- .../llm/src/ipex_llm/utils/ipex_importer.py | 45 +++++++++++++++++-- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/python/llm/src/ipex_llm/utils/ipex_importer.py b/python/llm/src/ipex_llm/utils/ipex_importer.py index b04c2674..0b60e48c 100644 --- a/python/llm/src/ipex_llm/utils/ipex_importer.py +++ b/python/llm/src/ipex_llm/utils/ipex_importer.py @@ -19,6 +19,43 @@ import logging import builtins import sys from ipex_llm.utils.common import log4Error +import inspect + +# Save the original __import__ function +original_import = builtins.__import__ +ipex_duplicate_import_error = "intel_extension_for_pytorch has already been automatically " + \ + "imported. Please avoid importing it again!" + + +def get_calling_package(): + """ + Return calling package name, e.g., ipex_llm.transformers + """ + # Get the current stack frame + frame = inspect.currentframe() + # Get the caller's frame + caller_frame = frame.f_back.f_back + # Get the caller's module + module = inspect.getmodule(caller_frame) + if module: + # Return the module's package name + return module.__package__ + return None + + +def custom_ipex_import(name, globals=None, locals=None, fromlist=(), level=0): + """ + Custom import function to avoid importing ipex again + """ + # check import calling pacage + calling_package = get_calling_package() + if calling_package is not None: + return original_import(name, globals, locals, fromlist, level) + # Only check ipex for main thread + if name == "ipex" or name == "intel_extension_for_pytorch": + log4Error.invalidInputError(False, + ipex_duplicate_import_error) + return original_import(name, globals, locals, fromlist, level) class IPEXImporter: @@ -61,12 +98,12 @@ class IPEXImporter: if self.is_xpu_version_installed(): # Check if user import ipex manually if 'ipex' in sys.modules or 'intel_extension_for_pytorch' in sys.modules: - logging.error("ipex_llm will automatically import intel_extension_for_pytorch.") log4Error.invalidInputError(False, - "Please import ipex_llm before importing \ - intel_extension_for_pytorch!") + ipex_duplicate_import_error) self.directly_import_ipex() self.ipex_version = ipex.__version__ + # Replace default importer + builtins.__import__ = custom_ipex_import logging.info("intel_extension_for_pytorch auto imported") def directly_import_ipex(self): @@ -95,6 +132,8 @@ class IPEXImporter: # try to import Intel Extension for PyTorch and get version self.directly_import_ipex() self.ipex_version = ipex.__version__ + # Replace default importer + builtins.__import__ = custom_ipex_import return self.ipex_version