LLM: Support load_low_bit loading models in shards format (#8612)

* shards_model

---------

Co-authored-by: leonardozcm <leonaordo1997zcm@gmail.com>
This commit is contained in:
Zhao Changmin 2023-07-26 13:30:01 +08:00 committed by GitHub
parent 919791e406
commit 5b484ab48d
2 changed files with 92 additions and 8 deletions

View file

@ -14,9 +14,14 @@
# limitations under the License.
#
import gc
import transformers
from transformers.configuration_utils import PretrainedConfig
from .utils import extract_local_archive_file, load_state_dict, load
from .utils import extract_local_archive_file, \
load_state_dict, \
load, \
get_local_shard_files, \
fix_key
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
@ -147,12 +152,51 @@ class _BaseAutoModelClass:
# and the tensor shape of int4 weights without quantization.
model = ggml_convert_quant(model, qtype, convert_shape_only=True)
# Load the quantized model at last.
archive_file = extract_local_archive_file(pretrained_model_name_or_path,
subfolder,
variant)
state_dict = load_state_dict(archive_file)
load(model, state_dict)
del state_dict
resolved_archive_file, is_sharded = extract_local_archive_file(
pretrained_model_name_or_path,
subfolder,
variant)
if is_sharded:
resolved_archive_file, sharded_metadata = \
get_local_shard_files(pretrained_model_name_or_path,
resolved_archive_file,
subfolder=subfolder)
start_prefix = ""
prefix = model.base_model_prefix
loaded_keys = [fix_key(key) for key in sharded_metadata["all_checkpoint_keys"]]
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
else:
has_prefix_module = False
model_cls = type(model)
if len(model_cls.base_model_prefix) > 0 and \
not hasattr(model, model_cls.base_model_prefix) and \
has_prefix_module:
start_prefix = model_cls.base_model_prefix + "."
from transformers.modeling_utils import _load_state_dict_into_model
error_msgs = []
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
error_msgs += _load_state_dict_into_model(model, state_dict, start_prefix)
# force memory release
del state_dict
gc.collect()
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg:
error_msg += (
"\n\tYou may consider adding `ignore_mismatched_sizes=True`"
" in the model `from_pretrained` method."
)
invalidInputError(False, "Error(s) in loading state_dict"
f"for {model.__class__.__name__}:\n\t{error_msg}")
else:
state_dict = load_state_dict(resolved_archive_file)
load(model, state_dict)
del state_dict
return model

View file

@ -48,6 +48,7 @@ from torch import nn
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant):
@ -59,7 +60,18 @@ def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
)
return archive_file
return archive_file, False
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path,
subfolder,
_add_variant(WEIGHTS_INDEX_NAME, variant))
):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
)
is_sharded = True
return archive_file, is_sharded
else:
invalidInputError(False,
f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}"
@ -89,3 +101,31 @@ def load(module: nn.Module, state_dict, prefix=""):
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")
def get_local_shard_files(pretrained_model_name_or_path, index_filename, subfolder=""):
import json
invalidInputError(os.path.isfile(index_filename),
"Can't find a checkpoint index"
f" ({index_filename}) in {pretrained_model_name_or_path}.")
with open(index_filename, "r") as f:
index = json.loads(f.read())
shard_filenames = sorted(set(index["weight_map"].values()))
sharded_metadata = index["metadata"]
sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
sharded_metadata["weight_map"] = index["weight_map"].copy()
shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f)
for f in shard_filenames]
return shard_filenames, sharded_metadata
def fix_key(key):
if "beta" in key:
return key.replace("beta", "bias")
if "gamma" in key:
return key.replace("gamma", "weight")
return key