LLM: Mute shape mismatch output (#8601)
* LLM: Mute shape mismatch output
This commit is contained in:
parent
15b3adc7ec
commit
ca998cc6f2
3 changed files with 17 additions and 6 deletions
|
|
@ -23,7 +23,7 @@ from .utils import extract_local_archive_file, \
|
||||||
get_local_shard_files, \
|
get_local_shard_files, \
|
||||||
fix_key
|
fix_key
|
||||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
from bigdl.llm.utils.common import invalidInputError, MuteHFLogger
|
||||||
|
|
||||||
|
|
||||||
def save_low_bit(self, *args, **kwargs):
|
def save_low_bit(self, *args, **kwargs):
|
||||||
|
|
@ -166,10 +166,9 @@ class _BaseAutoModelClass:
|
||||||
variant = kwargs.get("variant", None)
|
variant = kwargs.get("variant", None)
|
||||||
|
|
||||||
from .convert import ggml_convert_quant
|
from .convert import ggml_convert_quant
|
||||||
|
|
||||||
|
with MuteHFLogger(logger=transformers.modeling_utils.logger):
|
||||||
model = cls.HF_Model.from_pretrained(*args, **kwargs)
|
model = cls.HF_Model.from_pretrained(*args, **kwargs)
|
||||||
print("Note: If there are warnings during the model loading process, "
|
|
||||||
"they can be safely ignored; "
|
|
||||||
"the model will be loaded with INT4 optimizations applied.")
|
|
||||||
|
|
||||||
# add save_low_bit to pretrained model dynamically
|
# add save_low_bit to pretrained model dynamically
|
||||||
import types
|
import types
|
||||||
|
|
|
||||||
|
|
@ -19,5 +19,5 @@
|
||||||
# Otherwise there would be module not found error in non-pip's setting as Python would
|
# Otherwise there would be module not found error in non-pip's setting as Python would
|
||||||
# only search the first bigdl package and end up finding only one sub-package.
|
# only search the first bigdl package and end up finding only one sub-package.
|
||||||
|
|
||||||
from .log4Error import invalidInputError, invalidOperationError
|
from .log4Error import invalidInputError, invalidOperationError, MuteHFLogger
|
||||||
from .lazyimport import LazyImport
|
from .lazyimport import LazyImport
|
||||||
|
|
|
||||||
|
|
@ -39,3 +39,15 @@ def invalidOperationError(condition, errMsg, fixMsg=None, cause=None):
|
||||||
raise cause
|
raise cause
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(errMsg)
|
raise RuntimeError(errMsg)
|
||||||
|
|
||||||
|
class MuteHFLogger():
|
||||||
|
def __init__(self, logger, speak_level=logging.ERROR) -> None:
|
||||||
|
self.logger = logger
|
||||||
|
self.speak_level = speak_level
|
||||||
|
self.old_level = logger.getEffectiveLevel()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.logger.setLevel(self.speak_level)
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
self.logger.setLevel(self.old_level)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue