LLM: Mute shape mismatch output (#8601)
* LLM: Mute shape mismatch output
This commit is contained in:
		
							parent
							
								
									15b3adc7ec
								
							
						
					
					
						commit
						ca998cc6f2
					
				
					 3 changed files with 17 additions and 6 deletions
				
			
		| 
						 | 
				
			
			@ -23,7 +23,7 @@ from .utils import extract_local_archive_file, \
 | 
			
		|||
    get_local_shard_files, \
 | 
			
		||||
    fix_key
 | 
			
		||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError, MuteHFLogger
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def save_low_bit(self, *args, **kwargs):
 | 
			
		||||
| 
						 | 
				
			
			@ -166,10 +166,9 @@ class _BaseAutoModelClass:
 | 
			
		|||
        variant = kwargs.get("variant", None)
 | 
			
		||||
 | 
			
		||||
        from .convert import ggml_convert_quant
 | 
			
		||||
        model = cls.HF_Model.from_pretrained(*args, **kwargs)
 | 
			
		||||
        print("Note: If there are warnings during the model loading process, "
 | 
			
		||||
              "they can be safely ignored; "
 | 
			
		||||
              "the model will be loaded with INT4 optimizations applied.")
 | 
			
		||||
 | 
			
		||||
        with MuteHFLogger(logger=transformers.modeling_utils.logger):
 | 
			
		||||
            model = cls.HF_Model.from_pretrained(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        # add save_low_bit to pretrained model dynamically
 | 
			
		||||
        import types
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,5 +19,5 @@
 | 
			
		|||
# Otherwise there would be module not found error in non-pip's setting as Python would
 | 
			
		||||
# only search the first bigdl package and end up finding only one sub-package.
 | 
			
		||||
 | 
			
		||||
from .log4Error import invalidInputError, invalidOperationError
 | 
			
		||||
from .log4Error import invalidInputError, invalidOperationError, MuteHFLogger
 | 
			
		||||
from .lazyimport import LazyImport
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -39,3 +39,15 @@ def invalidOperationError(condition, errMsg, fixMsg=None, cause=None):
 | 
			
		|||
            raise cause
 | 
			
		||||
        else:
 | 
			
		||||
            raise RuntimeError(errMsg)
 | 
			
		||||
 | 
			
		||||
class MuteHFLogger():
 | 
			
		||||
    def __init__(self, logger, speak_level=logging.ERROR) -> None:
 | 
			
		||||
        self.logger = logger
 | 
			
		||||
        self.speak_level = speak_level
 | 
			
		||||
        self.old_level = logger.getEffectiveLevel()
 | 
			
		||||
 | 
			
		||||
    def __enter__(self):
 | 
			
		||||
        self.logger.setLevel(self.speak_level)
 | 
			
		||||
 | 
			
		||||
    def __exit__(self, exc_type, exc_value, traceback):
 | 
			
		||||
        self.logger.setLevel(self.old_level)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue