diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index f9424ccc..29670b1d 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -14,9 +14,14 @@
 # limitations under the License.
 #
 
+import gc
 import transformers
 from transformers.configuration_utils import PretrainedConfig
-from .utils import extract_local_archive_file, load_state_dict, load
+from .utils import extract_local_archive_file, \
+    load_state_dict, \
+    load, \
+    get_local_shard_files, \
+    fix_key
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 
@@ -147,12 +152,51 @@ class _BaseAutoModelClass:
         # and the tensor shape of int4 weights without quantization.
         model = ggml_convert_quant(model, qtype, convert_shape_only=True)
         # Load the quantized model at last.
-        archive_file = extract_local_archive_file(pretrained_model_name_or_path,
-                                                  subfolder,
-                                                  variant)
-        state_dict = load_state_dict(archive_file)
-        load(model, state_dict)
-        del state_dict
+        resolved_archive_file, is_sharded = extract_local_archive_file(
+            pretrained_model_name_or_path,
+            subfolder,
+            variant)
+        if is_sharded:
+            resolved_archive_file, sharded_metadata = \
+                get_local_shard_files(pretrained_model_name_or_path,
+                                      resolved_archive_file,
+                                      subfolder=subfolder)
+            start_prefix = ""
+            prefix = model.base_model_prefix
+            loaded_keys = [fix_key(key) for key in sharded_metadata["all_checkpoint_keys"]]
+            if len(prefix) > 0:
+                has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
+            else:
+                has_prefix_module = False
+
+            model_cls = type(model)
+            if len(model_cls.base_model_prefix) > 0 and \
+                    not hasattr(model, model_cls.base_model_prefix) and \
+                    has_prefix_module:
+                start_prefix = model_cls.base_model_prefix + "."
+            from transformers.modeling_utils import _load_state_dict_into_model
+            error_msgs = []
+            for shard_file in resolved_archive_file:
+                state_dict = load_state_dict(shard_file)
+                error_msgs += _load_state_dict_into_model(model, state_dict, start_prefix)
+                # force memory release
+                del state_dict
+                gc.collect()
+
+            if len(error_msgs) > 0:
+                error_msg = "\n\t".join(error_msgs)
+                if "size mismatch" in error_msg:
+                    error_msg += (
+                        "\n\tYou may consider adding `ignore_mismatched_sizes=True`"
+                        " in the model `from_pretrained` method."
+                    )
+                invalidInputError(False, "Error(s) in loading state_dict"
+                                  f" for {model.__class__.__name__}:\n\t{error_msg}")
+
+        else:
+            state_dict = load_state_dict(resolved_archive_file)
+            load(model, state_dict)
+            del state_dict
 
         return model
 
diff --git a/python/llm/src/bigdl/llm/transformers/utils.py b/python/llm/src/bigdl/llm/transformers/utils.py
index 837fd5dc..7be59e4f 100644
--- a/python/llm/src/bigdl/llm/transformers/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/utils.py
@@ -48,6 +48,7 @@ from torch import nn
 
 
 WEIGHTS_NAME = "pytorch_model.bin"
+WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 
 
 def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant):
@@ -59,7 +60,18 @@ def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant
         archive_file = os.path.join(
             pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
         )
-        return archive_file
+        return archive_file, False
+    elif os.path.isfile(
+        os.path.join(pretrained_model_name_or_path,
+                     subfolder,
+                     _add_variant(WEIGHTS_INDEX_NAME, variant))
+    ):
+        # Load from a sharded PyTorch checkpoint
+        archive_file = os.path.join(
+            pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+        )
+        is_sharded = True
+        return archive_file, is_sharded
     else:
         invalidInputError(False,
                           f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}"
@@ -89,3 +101,31 @@ def load(module: nn.Module, state_dict, prefix=""):
     for name, child in module._modules.items():
         if child is not None:
             load(child, state_dict, prefix + name + ".")
+
+
+def get_local_shard_files(pretrained_model_name_or_path, index_filename, subfolder=""):
+    import json
+
+    invalidInputError(os.path.isfile(index_filename),
+                      "Can't find a checkpoint index"
+                      f" ({index_filename}) in {pretrained_model_name_or_path}.")
+
+    with open(index_filename, "r") as f:
+        index = json.loads(f.read())
+
+    shard_filenames = sorted(set(index["weight_map"].values()))
+    sharded_metadata = index["metadata"]
+    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
+    sharded_metadata["weight_map"] = index["weight_map"].copy()
+
+    shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f)
+                       for f in shard_filenames]
+    return shard_filenames, sharded_metadata
+
+
+def fix_key(key):
+    if "beta" in key:
+        return key.replace("beta", "bias")
+    if "gamma" in key:
+        return key.replace("gamma", "weight")
+    return key
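
Usage sketch (not part of the patch): a minimal smoke test for the new sharded-checkpoint helpers. It fabricates a two-shard checkpoint layout in a temporary directory and exercises extract_local_archive_file, get_local_shard_files, and fix_key. It assumes the patched bigdl-llm is importable and that an empty subfolder plus a None variant are acceptable arguments, which the hunks above suggest but do not show; the shard and key names are made up for illustration.

    import json
    import os
    import tempfile

    from bigdl.llm.transformers.utils import (extract_local_archive_file,
                                              get_local_shard_files, fix_key)

    with tempfile.TemporaryDirectory() as ckpt_dir:
        # Fabricate a sharded checkpoint: two empty shard files plus the
        # pytorch_model.bin.index.json that marks the checkpoint as sharded.
        shards = ["pytorch_model-00001-of-00002.bin",
                  "pytorch_model-00002-of-00002.bin"]
        index = {"metadata": {"total_size": 0},
                 "weight_map": {"model.embed.weight": shards[0],
                                "model.lm_head.gamma": shards[1]}}
        for name in shards:
            open(os.path.join(ckpt_dir, name), "wb").close()
        with open(os.path.join(ckpt_dir, "pytorch_model.bin.index.json"), "w") as f:
            json.dump(index, f)

        # With no pytorch_model.bin present, the new elif branch should fire
        # and report the checkpoint as sharded.
        archive_file, is_sharded = extract_local_archive_file(ckpt_dir, "", None)
        assert is_sharded

        # get_local_shard_files resolves the index into absolute shard paths
        # plus metadata carrying every checkpoint key.
        shard_files, sharded_metadata = get_local_shard_files(ckpt_dir, archive_file)
        assert [os.path.basename(p) for p in shard_files] == shards
        assert sharded_metadata["all_checkpoint_keys"] == list(index["weight_map"])

        # fix_key maps legacy TF-style names (gamma/beta) onto the
        # weight/bias names PyTorch modules expect.
        assert fix_key("model.lm_head.gamma") == "model.lm_head.weight"
        assert fix_key("model.lm_head.beta") == "model.lm_head.bias"

Note the del state_dict / gc.collect() pair inside the per-shard loop in model.py: releasing each shard before loading the next keeps peak memory near the size of one shard rather than the whole checkpoint.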