LLM: Mute shape mismatch output (#8601)

* LLM: Mute shape mismatch output
This commit is contained in:
Zhao Changmin 2023-08-02 16:46:22 +08:00 committed by GitHub
parent 15b3adc7ec
commit ca998cc6f2
3 changed files with 17 additions and 6 deletions

View file

@ -23,7 +23,7 @@ from .utils import extract_local_archive_file, \
get_local_shard_files, \
fix_key
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.utils.common import invalidInputError, MuteHFLogger
def save_low_bit(self, *args, **kwargs):
@ -166,10 +166,9 @@ class _BaseAutoModelClass:
variant = kwargs.get("variant", None)
from .convert import ggml_convert_quant
model = cls.HF_Model.from_pretrained(*args, **kwargs)
print("Note: If there are warnings during the model loading process, "
"they can be safely ignored; "
"the model will be loaded with INT4 optimizations applied.")
with MuteHFLogger(logger=transformers.modeling_utils.logger):
model = cls.HF_Model.from_pretrained(*args, **kwargs)
# add save_low_bit to pretrained model dynamically
import types

View file

@ -19,5 +19,5 @@
# Otherwise there would be module not found error in non-pip's setting as Python would
# only search the first bigdl package and end up finding only one sub-package.
from .log4Error import invalidInputError, invalidOperationError
from .log4Error import invalidInputError, invalidOperationError, MuteHFLogger
from .lazyimport import LazyImport

View file

@ -39,3 +39,15 @@ def invalidOperationError(condition, errMsg, fixMsg=None, cause=None):
raise cause
else:
raise RuntimeError(errMsg)
class MuteHFLogger():
def __init__(self, logger, speak_level=logging.ERROR) -> None:
self.logger = logger
self.speak_level = speak_level
self.old_level = logger.getEffectiveLevel()
def __enter__(self):
self.logger.setLevel(self.speak_level)
def __exit__(self, exc_type, exc_value, traceback):
self.logger.setLevel(self.old_level)