162 lines
5.2 KiB
Python
162 lines
5.2 KiB
Python
#
|
|
# Copyright 2016 The BigDL Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from importlib.metadata import distribution, PackageNotFoundError
|
|
import logging
|
|
import builtins
|
|
import sys
|
|
import os
|
|
import inspect
|
|
from ipex_llm.utils.common import log4Error
|
|
|
|
|
|
# Switch for the IPEX duplicate-import checker, enabled by default. Setting the
# BIGDL_CHECK_DUPLICATE_IMPORT environment variable to anything other than
# "true", "1" or "t" (case-insensitive) disables it.
BIGDL_CHECK_DUPLICATE_IMPORT = os.environ.get(
    "BIGDL_CHECK_DUPLICATE_IMPORT", "True"
).lower() in {"true", "1", "t"}

# The original builtins.__import__, stashed by replace_import(); None until then.
RAW_IMPORT = None
# True while builtins.__import__ points at custom_ipex_import.
IS_IMPORT_REPLACED = False

# Error message reported when ipex is imported again after the auto-import.
ipex_duplicate_import_error = (
    "intel_extension_for_pytorch has already been automatically "
    "imported. Please avoid importing it again!"
)
|
|
|
|
|
|
def replace_import():
    """
    Install ``custom_ipex_import`` as ``builtins.__import__``.

    The import function that was active before is saved in ``RAW_IMPORT`` so
    the hook can forward to it and ``revert_import`` can restore it. The hook
    is installed at most once. Nothing happens if it is currently active, or
    if ``RAW_IMPORT`` was ever populated. This includes the case after
    ``revert_import``, which keeps ``RAW_IMPORT`` set.
    """
    global RAW_IMPORT, IS_IMPORT_REPLACED
    if IS_IMPORT_REPLACED or RAW_IMPORT is not None:
        # Never wrap the hook around itself.
        return
    RAW_IMPORT = builtins.__import__
    builtins.__import__ = custom_ipex_import
    IS_IMPORT_REPLACED = True
|
|
|
|
|
|
def revert_import():
    """
    Restore the ``builtins.__import__`` saved by ``replace_import``.

    This is a no-op in two cases: when the duplicate-import check is disabled
    through ``BIGDL_CHECK_DUPLICATE_IMPORT``, and when the hook is not
    currently installed.
    """
    global IS_IMPORT_REPLACED
    if not BIGDL_CHECK_DUPLICATE_IMPORT:
        # import_ipex only installs the hook when the check is enabled.
        return
    if IS_IMPORT_REPLACED and RAW_IMPORT is not None:
        builtins.__import__ = RAW_IMPORT
        # RAW_IMPORT is deliberately left set, so replace_import() will not
        # install the hook a second time.
        IS_IMPORT_REPLACED = False
|
|
|
|
|
|
def get_calling_package():
    """
    Return the package of the code two frames above this call, e.g.
    ipex_llm.transformers.

    Intended to be called directly from ``custom_ipex_import``. From there,
    two frames up is the code that executed the import statement. Returns
    that module's ``__package__``. Returns None in three cases:

    - the frame cannot be inspected (``inspect.currentframe()`` returns None
      on interpreters without stack-frame support);
    - the stack is too shallow;
    - no module can be resolved for that frame.
    """
    frame = inspect.currentframe()
    try:
        if frame is None or frame.f_back is None:
            return None
        # frame -> our caller (custom_ipex_import) -> code running the import.
        # getmodule() returns None when the outer frame itself is None.
        module = inspect.getmodule(frame.f_back.f_back)
        if not module:
            return None
        # sys.modules may hold non-module objects lacking __package__.
        return getattr(module, "__package__", None)
    finally:
        # ``frame`` references this function's own frame, creating a reference
        # cycle. Break it deterministically, as the inspect docs recommend.
        del frame
|
|
|
|
|
|
def custom_ipex_import(name, globals=None, locals=None, fromlist=(), level=0):
    """
    Replacement for ``builtins.__import__`` that guards against importing
    ipex again.

    Every call is forwarded to the saved ``RAW_IMPORT``. Before forwarding,
    one kind of import is rejected: a plain ``import ipex`` or
    ``import intel_extension_for_pytorch`` (``fromlist`` is None and the name
    has no dot) issued from code whose module has no package
    (``get_calling_package()`` returns None). Rejection is reported through
    ``log4Error.invalidInputError`` with ``ipex_duplicate_import_error``.

    All other calls are forwarded without any check:

    - from-imports (non-None ``fromlist``);
    - dotted names;
    - imports from code inside a package;
    - direct calls that keep the default ``fromlist=()``, because only None
      counts as a plain import.
    """
    # ``from x import y`` style imports and submodule imports pass straight
    # through.
    passthrough = fromlist is not None or '.' in name
    # get_calling_package() must be called directly from here: it inspects
    # the frame two levels above itself. Imports made from inside a package
    # (library code) are not checked.
    if not passthrough and get_calling_package() is None:
        if name in ("ipex", "intel_extension_for_pytorch"):
            log4Error.invalidInputError(False,
                                        ipex_duplicate_import_error)
    return RAW_IMPORT(name, globals, locals, fromlist, level)
|
|
|
|
|
|
class IPEXImporter:
    """
    Auto import Intel Extension for PyTorch as ``ipex`` when the bigdl-llm
    XPU version is installed.

    "Installed" means the ``bigdl-core-xe`` or ``bigdl-core-xe-21``
    distribution is present. After a successful ``import_ipex``, the module
    is exposed as the builtin name ``ipex``. Its version is then available
    from ``get_ipex_version``.
    """

    def __init__(self):
        # ``ipex.__version__`` recorded by import_ipex(); None until then.
        self.ipex_version = None

    @staticmethod
    def is_xpu_version_installed():
        """
        Check whether the bigdl-llm XPU version is installed.

        Returns True if the ``bigdl-core-xe`` or ``bigdl-core-xe-21``
        distribution is installed, False otherwise.
        """
        for package in ("bigdl-core-xe", "bigdl-core-xe-21"):
            try:
                distribution(package)
            except PackageNotFoundError:
                # Not installed; try the next candidate.
                continue
            return True
        return False

    def import_ipex(self):
        """
        Import Intel Extension for PyTorch as ``ipex`` for XPU.

        Does nothing if the XPU version is not installed. When
        ``BIGDL_CHECK_DUPLICATE_IMPORT`` is enabled, two extra steps apply:

        - Before importing, an error is reported if ipex is already in
          ``sys.modules``.
        - After importing, ``replace_import`` installs a hook that rejects
          further plain ipex imports from package-less code.

        Raises ImportError if intel_extension_for_pytorch cannot be imported.
        Errors are reported via ``log4Error.invalidInputError``.
        """
        if not self.is_xpu_version_installed():
            return
        # Refuse to proceed if the user already imported ipex manually.
        if BIGDL_CHECK_DUPLICATE_IMPORT and (
                'ipex' in sys.modules
                or 'intel_extension_for_pytorch' in sys.modules):
            log4Error.invalidInputError(False,
                                        ipex_duplicate_import_error)
        self.directly_import_ipex()
        # directly_import_ipex() publishes the module as builtins.ipex. Read it
        # from there explicitly instead of via an undefined global ``ipex``.
        self.ipex_version = builtins.ipex.__version__
        # Replace builtin import to avoid duplicate ipex import.
        if BIGDL_CHECK_DUPLICATE_IMPORT:
            replace_import()
        logging.info("intel_extension_for_pytorch auto imported")

    def directly_import_ipex(self):
        """
        Import Intel Extension for PyTorch and expose it as the builtin name
        ``ipex``.

        Returns the imported module. Returns None only if the defensive
        check below fires and ``invalidInputError`` does not raise.

        Raises ImportError if the import fails. Reports an error via
        ``log4Error.invalidInputError`` if the import yields None.
        """
        import intel_extension_for_pytorch as ipex
        # Defensive: a standard import never binds None here. Only a
        # non-standard __import__ replacement could.
        if ipex is None:
            log4Error.invalidInputError(False,
                                        "Can not import intel_extension_for_pytorch.")
            return None
        # Expose ipex to Python builtins.
        builtins.ipex = ipex
        return ipex

    def get_ipex_version(self):
        """
        Get the version of the auto-imported Intel Extension for PyTorch.

        Returns the ``ipex.__version__`` string recorded by ``import_ipex``.
        Returns None if this importer has not imported ipex.
        """
        return self.ipex_version
|
|
|
|
|
|
# Shared module-level instance. Call ``ipex_importer.import_ipex()`` to perform
# the automatic import and ``ipex_importer.get_ipex_version()`` to read the
# version it recorded.
ipex_importer = IPEXImporter()
|