IPEX Duplicate importer V2 (#11310)
* Add gguf support. * Avoid error when import ipex-llm for multiple times. * Add check to avoid duplicate replace and revert. * Add calling from check to avoid raising exceptions in the submodule. * Add BIGDL_CHECK_DUPLICATE_IMPORT for controlling duplicate checker. Default is true.
This commit is contained in:
		
							parent
							
								
									271d82a4fc
								
							
						
					
					
						commit
						1eb884a249
					
				
					 4 changed files with 74 additions and 5 deletions
				
			
		| 
						 | 
					@ -26,13 +26,15 @@ from .llm_patching import llm_patch, llm_unpatch
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
import types
 | 
					import types
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Default is false, set to true to auto importing Intel Extension for PyTorch.
 | 
					# Default is True, set to False to disable auto importing Intel Extension for PyTorch.
 | 
				
			||||||
USE_NPU = os.getenv("BIGDL_USE_NPU", 'False').lower() in ('true', '1', 't')
 | 
					USE_NPU = os.getenv("BIGDL_USE_NPU", 'False').lower() in ('true', '1', 't')
 | 
				
			||||||
BIGDL_IMPORT_IPEX = os.getenv("BIGDL_IMPORT_IPEX", 'True').lower() in ('true', '1', 't')
 | 
					BIGDL_IMPORT_IPEX = os.getenv("BIGDL_IMPORT_IPEX", 'True').lower() in ('true', '1', 't')
 | 
				
			||||||
BIGDL_IMPORT_IPEX = not USE_NPU and BIGDL_IMPORT_IPEX
 | 
					BIGDL_IMPORT_IPEX = not USE_NPU and BIGDL_IMPORT_IPEX
 | 
				
			||||||
if BIGDL_IMPORT_IPEX:
 | 
					if BIGDL_IMPORT_IPEX:
 | 
				
			||||||
    # Import Intel Extension for PyTorch as ipex if XPU version is installed
 | 
					    # Import Intel Extension for PyTorch as ipex if XPU version is installed
 | 
				
			||||||
    from .utils.ipex_importer import ipex_importer
 | 
					    from .utils.ipex_importer import ipex_importer
 | 
				
			||||||
 | 
					    # Avoid duplicate import
 | 
				
			||||||
 | 
					    if ipex_importer.get_ipex_version() is None:
 | 
				
			||||||
        ipex_importer.import_ipex()
 | 
					        ipex_importer.import_ipex()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Default is true, set to true to auto patching bigdl-llm to ipex_llm.
 | 
					# Default is true, set to true to auto patching bigdl-llm to ipex_llm.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -773,6 +773,9 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 | 
				
			||||||
                    f"{list(gguf_mixed_qtype.keys())[index]} "
 | 
					                    f"{list(gguf_mixed_qtype.keys())[index]} "
 | 
				
			||||||
                    f"format......")
 | 
					                    f"format......")
 | 
				
			||||||
    modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
 | 
					    modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
 | 
				
			||||||
 | 
					    # Disable ipex duplicate import checker
 | 
				
			||||||
 | 
					    from ipex_llm.utils.ipex_importer import revert_import
 | 
				
			||||||
 | 
					    revert_import()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # using ipex_llm optimizer before changing to bigdl linear
 | 
					    # using ipex_llm optimizer before changing to bigdl linear
 | 
				
			||||||
    _enable_ipex = get_enable_ipex()
 | 
					    _enable_ipex = get_enable_ipex()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -30,6 +30,9 @@ qtype_map = {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def load_gguf_model(fpath: str, dtype: torch.dtype = torch.float, low_bit: str = "sym_int4"):
 | 
					def load_gguf_model(fpath: str, dtype: torch.dtype = torch.float, low_bit: str = "sym_int4"):
 | 
				
			||||||
    from .gguf import GGUFFileLoader
 | 
					    from .gguf import GGUFFileLoader
 | 
				
			||||||
 | 
					    # Disable ipex duplicate import checker
 | 
				
			||||||
 | 
					    from ipex_llm.utils.ipex_importer import revert_import
 | 
				
			||||||
 | 
					    revert_import()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    loader = GGUFFileLoader(fpath)
 | 
					    loader = GGUFFileLoader(fpath)
 | 
				
			||||||
    model_family = loader.config["general.architecture"]
 | 
					    model_family = loader.config["general.architecture"]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,15 +18,73 @@ from importlib.metadata import distribution, PackageNotFoundError
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import builtins
 | 
					import builtins
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
from ipex_llm.utils.common import log4Error
 | 
					import os
 | 
				
			||||||
import inspect
 | 
					import inspect
 | 
				
			||||||
 | 
					from ipex_llm.utils.common import log4Error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Save the original __import__ function
 | 
					
 | 
				
			||||||
original_import = builtins.__import__
 | 
					# Default is True, set to False to disable IPEX duplicate checker
 | 
				
			||||||
 | 
					BIGDL_CHECK_DUPLICATE_IMPORT = os.getenv("BIGDL_CHECK_DUPLICATE_IMPORT",
 | 
				
			||||||
 | 
					                                         'True').lower() in ('true', '1', 't')
 | 
				
			||||||
 | 
					RAW_IMPORT = None
 | 
				
			||||||
 | 
					IS_IMPORT_REPLACED = False
 | 
				
			||||||
ipex_duplicate_import_error = "intel_extension_for_pytorch has already been automatically " + \
 | 
					ipex_duplicate_import_error = "intel_extension_for_pytorch has already been automatically " + \
 | 
				
			||||||
    "imported. Please avoid importing it again!"
 | 
					    "imported. Please avoid importing it again!"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def replace_import():
 | 
				
			||||||
 | 
					    global RAW_IMPORT, IS_IMPORT_REPLACED
 | 
				
			||||||
 | 
					    # Avoid multiple replacement
 | 
				
			||||||
 | 
					    if not IS_IMPORT_REPLACED and RAW_IMPORT is None:
 | 
				
			||||||
 | 
					        # Save the original __import__ function
 | 
				
			||||||
 | 
					        RAW_IMPORT = builtins.__import__
 | 
				
			||||||
 | 
					        builtins.__import__ = custom_ipex_import
 | 
				
			||||||
 | 
					        IS_IMPORT_REPLACED = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def revert_import():
 | 
				
			||||||
 | 
					    if not BIGDL_CHECK_DUPLICATE_IMPORT:
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					    global RAW_IMPORT, IS_IMPORT_REPLACED
 | 
				
			||||||
 | 
					    # Only revert once
 | 
				
			||||||
 | 
					    if RAW_IMPORT is not None and IS_IMPORT_REPLACED:
 | 
				
			||||||
 | 
					        builtins.__import__ = RAW_IMPORT
 | 
				
			||||||
 | 
					        IS_IMPORT_REPLACED = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_calling_package():
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Return calling package name, e.g., ipex_llm.transformers
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    # Get the current stack frame
 | 
				
			||||||
 | 
					    frame = inspect.currentframe()
 | 
				
			||||||
 | 
					    # Get the caller's frame
 | 
				
			||||||
 | 
					    caller_frame = frame.f_back.f_back
 | 
				
			||||||
 | 
					    # Get the caller's module
 | 
				
			||||||
 | 
					    module = inspect.getmodule(caller_frame)
 | 
				
			||||||
 | 
					    if module:
 | 
				
			||||||
 | 
					        # Return the module's package name
 | 
				
			||||||
 | 
					        return module.__package__
 | 
				
			||||||
 | 
					    return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def custom_ipex_import(name, globals=None, locals=None, fromlist=(), level=0):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Custom import function to avoid importing ipex again
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if fromlist is not None or '.' in name:
 | 
				
			||||||
 | 
					        return RAW_IMPORT(name, globals, locals, fromlist, level)
 | 
				
			||||||
 | 
					    # Avoid errors in submodule import
 | 
				
			||||||
 | 
					    calling = get_calling_package()
 | 
				
			||||||
 | 
					    if calling is not None:
 | 
				
			||||||
 | 
					        return RAW_IMPORT(name, globals, locals, fromlist, level)
 | 
				
			||||||
 | 
					    # Only check ipex for main thread
 | 
				
			||||||
 | 
					    if name == "ipex" or name == "intel_extension_for_pytorch":
 | 
				
			||||||
 | 
					        log4Error.invalidInputError(False,
 | 
				
			||||||
 | 
					                                    ipex_duplicate_import_error)
 | 
				
			||||||
 | 
					    return RAW_IMPORT(name, globals, locals, fromlist, level)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class IPEXImporter:
 | 
					class IPEXImporter:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Auto import Intel Extension for PyTorch as ipex,
 | 
					    Auto import Intel Extension for PyTorch as ipex,
 | 
				
			||||||
| 
						 | 
					@ -71,6 +129,9 @@ class IPEXImporter:
 | 
				
			||||||
                                            ipex_duplicate_import_error)
 | 
					                                            ipex_duplicate_import_error)
 | 
				
			||||||
            self.directly_import_ipex()
 | 
					            self.directly_import_ipex()
 | 
				
			||||||
            self.ipex_version = ipex.__version__
 | 
					            self.ipex_version = ipex.__version__
 | 
				
			||||||
 | 
					            # Replace builtin import to avoid duplicate ipex import
 | 
				
			||||||
 | 
					            if BIGDL_CHECK_DUPLICATE_IMPORT:
 | 
				
			||||||
 | 
					                replace_import()
 | 
				
			||||||
            logging.info("intel_extension_for_pytorch auto imported")
 | 
					            logging.info("intel_extension_for_pytorch auto imported")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def directly_import_ipex(self):
 | 
					    def directly_import_ipex(self):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue