LLM: Support load_low_bit loading models in shards format (#8612)
* shards_model --------- Co-authored-by: leonardozcm <leonaordo1997zcm@gmail.com>
This commit is contained in:
		
							parent
							
								
									919791e406
								
							
						
					
					
						commit
						5b484ab48d
					
				
					 2 changed files with 92 additions and 8 deletions
				
			
		| 
						 | 
				
			
			@ -14,9 +14,14 @@
 | 
			
		|||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import gc
 | 
			
		||||
import transformers
 | 
			
		||||
from transformers.configuration_utils import PretrainedConfig
 | 
			
		||||
from .utils import extract_local_archive_file, load_state_dict, load
 | 
			
		||||
from .utils import extract_local_archive_file, \
 | 
			
		||||
    load_state_dict, \
 | 
			
		||||
    load, \
 | 
			
		||||
    get_local_shard_files, \
 | 
			
		||||
    fix_key
 | 
			
		||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -147,12 +152,51 @@ class _BaseAutoModelClass:
 | 
			
		|||
        # and the tensor shape of int4 weights without quantization.
 | 
			
		||||
        model = ggml_convert_quant(model, qtype, convert_shape_only=True)
 | 
			
		||||
        # Load the quantized model at last.
 | 
			
		||||
        archive_file = extract_local_archive_file(pretrained_model_name_or_path,
 | 
			
		||||
                                                  subfolder,
 | 
			
		||||
                                                  variant)
 | 
			
		||||
        state_dict = load_state_dict(archive_file)
 | 
			
		||||
        load(model, state_dict)
 | 
			
		||||
        del state_dict
 | 
			
		||||
        resolved_archive_file, is_sharded = extract_local_archive_file(
 | 
			
		||||
            pretrained_model_name_or_path,
 | 
			
		||||
            subfolder,
 | 
			
		||||
            variant)
 | 
			
		||||
        if is_sharded:
 | 
			
		||||
            resolved_archive_file, sharded_metadata = \
 | 
			
		||||
                get_local_shard_files(pretrained_model_name_or_path,
 | 
			
		||||
                                      resolved_archive_file,
 | 
			
		||||
                                      subfolder=subfolder)
 | 
			
		||||
            start_prefix = ""
 | 
			
		||||
            prefix = model.base_model_prefix
 | 
			
		||||
            loaded_keys = [fix_key(key) for key in sharded_metadata["all_checkpoint_keys"]]
 | 
			
		||||
            if len(prefix) > 0:
 | 
			
		||||
                has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
 | 
			
		||||
            else:
 | 
			
		||||
                has_prefix_module = False
 | 
			
		||||
 | 
			
		||||
            model_cls = type(model)
 | 
			
		||||
            if len(model_cls.base_model_prefix) > 0 and \
 | 
			
		||||
                not hasattr(model, model_cls.base_model_prefix) and \
 | 
			
		||||
                    has_prefix_module:
 | 
			
		||||
                start_prefix = model_cls.base_model_prefix + "."
 | 
			
		||||
            from transformers.modeling_utils import _load_state_dict_into_model
 | 
			
		||||
            error_msgs = []
 | 
			
		||||
            for shard_file in resolved_archive_file:
 | 
			
		||||
                state_dict = load_state_dict(shard_file)
 | 
			
		||||
                error_msgs += _load_state_dict_into_model(model, state_dict, start_prefix)
 | 
			
		||||
                # force memory release
 | 
			
		||||
                del state_dict
 | 
			
		||||
                gc.collect()
 | 
			
		||||
 | 
			
		||||
            if len(error_msgs) > 0:
 | 
			
		||||
                error_msg = "\n\t".join(error_msgs)
 | 
			
		||||
                if "size mismatch" in error_msg:
 | 
			
		||||
                    error_msg += (
 | 
			
		||||
                        "\n\tYou may consider adding `ignore_mismatched_sizes=True`"
 | 
			
		||||
                        " in the model `from_pretrained` method."
 | 
			
		||||
                    )
 | 
			
		||||
                invalidInputError(False, "Error(s) in loading state_dict"
 | 
			
		||||
                                         f"for {model.__class__.__name__}:\n\t{error_msg}")
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            state_dict = load_state_dict(resolved_archive_file)
 | 
			
		||||
            load(model, state_dict)
 | 
			
		||||
            del state_dict
 | 
			
		||||
 | 
			
		||||
        return model
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -48,6 +48,7 @@ from torch import nn
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
# Filename of a single-file (non-sharded) PyTorch checkpoint.
WEIGHTS_NAME = "pytorch_model.bin"
# Filename of the JSON index that maps weight keys to shard files
# for a sharded PyTorch checkpoint.
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant):
 | 
			
		||||
| 
						 | 
				
			
			@ -59,7 +60,18 @@ def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant
 | 
			
		|||
        archive_file = os.path.join(
 | 
			
		||||
            pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
 | 
			
		||||
        )
 | 
			
		||||
        return archive_file
 | 
			
		||||
        return archive_file, False
 | 
			
		||||
    elif os.path.isfile(
 | 
			
		||||
        os.path.join(pretrained_model_name_or_path,
 | 
			
		||||
                     subfolder,
 | 
			
		||||
                     _add_variant(WEIGHTS_INDEX_NAME, variant))
 | 
			
		||||
    ):
 | 
			
		||||
        # Load from a sharded PyTorch checkpoint
 | 
			
		||||
        archive_file = os.path.join(
 | 
			
		||||
            pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
 | 
			
		||||
        )
 | 
			
		||||
        is_sharded = True
 | 
			
		||||
        return archive_file, is_sharded
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}"
 | 
			
		||||
| 
						 | 
				
			
			@ -89,3 +101,31 @@ def load(module: nn.Module, state_dict, prefix=""):
 | 
			
		|||
    for name, child in module._modules.items():
 | 
			
		||||
        if child is not None:
 | 
			
		||||
            load(child, state_dict, prefix + name + ".")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_local_shard_files(pretrained_model_name_or_path, index_filename, subfolder=""):
    """Resolve a local sharded-checkpoint index into shard paths and metadata.

    Reads the JSON index file of a sharded PyTorch checkpoint and returns a
    tuple ``(shard_filenames, sharded_metadata)`` where ``shard_filenames``
    is the sorted list of absolute shard paths (deduplicated) and
    ``sharded_metadata`` is the index's ``metadata`` dict augmented with
    ``all_checkpoint_keys`` (every key in the weight map) and a copy of
    ``weight_map`` itself.

    Raises via ``invalidInputError`` when ``index_filename`` does not exist.
    """
    import json

    invalidInputError(os.path.isfile(index_filename),
                      "Can't find a checkpoint index"
                      f" ({index_filename}) in {pretrained_model_name_or_path}.")

    with open(index_filename, "r") as index_file:
        index = json.load(index_file)

    weight_map = index["weight_map"]

    # Enrich the metadata so callers can inspect the full key set and
    # the key -> shard mapping without re-reading the index file.
    sharded_metadata = index["metadata"]
    sharded_metadata["all_checkpoint_keys"] = list(weight_map.keys())
    sharded_metadata["weight_map"] = weight_map.copy()

    # Several keys usually map to the same shard; deduplicate and sort
    # for a deterministic load order, then make the paths absolute.
    unique_shards = sorted(set(weight_map.values()))
    shard_filenames = [
        os.path.join(pretrained_model_name_or_path, subfolder, shard)
        for shard in unique_shards
    ]
    return shard_filenames, sharded_metadata
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fix_key(key):
    """Rewrite legacy TF-style parameter names to their PyTorch equivalents.

    Older checkpoints name LayerNorm parameters ``beta``/``gamma``; PyTorch
    modules expect ``bias``/``weight``. The first matching legacy substring
    wins (a key containing both is only rewritten for ``beta``); keys with
    neither substring are returned unchanged.
    """
    for legacy, modern in (("beta", "bias"), ("gamma", "weight")):
        if legacy in key:
            return key.replace(legacy, modern)
    return key
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue