LLM: Support load_low_bit loading models in shards format (#8612)
* shards_model

---------

Co-authored-by: leonardozcm <leonaordo1997zcm@gmail.com>

parent 919791e406 → commit 5b484ab48d
2 changed files with 92 additions and 8 deletions
@@ -14,9 +14,14 @@
 # limitations under the License.
 #
 
+import gc
 import transformers
 from transformers.configuration_utils import PretrainedConfig
-from .utils import extract_local_archive_file, load_state_dict, load
+from .utils import extract_local_archive_file, \
+    load_state_dict, \
+    load, \
+    get_local_shard_files, \
+    fix_key
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 
@@ -147,12 +152,51 @@ class _BaseAutoModelClass:
         # and the tensor shape of int4 weights without quantization.
         model = ggml_convert_quant(model, qtype, convert_shape_only=True)
         # Load the quantized model at last.
-        archive_file = extract_local_archive_file(pretrained_model_name_or_path,
-                                                  subfolder,
-                                                  variant)
-        state_dict = load_state_dict(archive_file)
-        load(model, state_dict)
-        del state_dict
+        resolved_archive_file, is_sharded = extract_local_archive_file(
+            pretrained_model_name_or_path,
+            subfolder,
+            variant)
+        if is_sharded:
+            resolved_archive_file, sharded_metadata = \
+                get_local_shard_files(pretrained_model_name_or_path,
+                                      resolved_archive_file,
+                                      subfolder=subfolder)
+            start_prefix = ""
+            prefix = model.base_model_prefix
+            loaded_keys = [fix_key(key) for key in sharded_metadata["all_checkpoint_keys"]]
+            if len(prefix) > 0:
+                has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
+            else:
+                has_prefix_module = False
+
+            model_cls = type(model)
+            if len(model_cls.base_model_prefix) > 0 and \
+                    not hasattr(model, model_cls.base_model_prefix) and \
+                    has_prefix_module:
+                start_prefix = model_cls.base_model_prefix + "."
+            from transformers.modeling_utils import _load_state_dict_into_model
+            error_msgs = []
+            for shard_file in resolved_archive_file:
+                state_dict = load_state_dict(shard_file)
+                error_msgs += _load_state_dict_into_model(model, state_dict, start_prefix)
+                # force memory release
+                del state_dict
+                gc.collect()
+
+            if len(error_msgs) > 0:
+                error_msg = "\n\t".join(error_msgs)
+                if "size mismatch" in error_msg:
+                    error_msg += (
+                        "\n\tYou may consider adding `ignore_mismatched_sizes=True`"
+                        " in the model `from_pretrained` method."
+                    )
+                invalidInputError(False, "Error(s) in loading state_dict "
+                                  f"for {model.__class__.__name__}:\n\t{error_msg}")
+
+        else:
+            state_dict = load_state_dict(resolved_archive_file)
+            load(model, state_dict)
+            del state_dict
 
         return model
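With this change, a low-bit checkpoint loads the same way whether it was saved as a single `pytorch_model.bin` or as shards plus an index file. A minimal usage sketch, assuming a local directory previously written by `save_low_bit()`; the path is a hypothetical placeholder, and `load_low_bit` is the public entry point that exercises this code path:

    from bigdl.llm.transformers import AutoModelForCausalLM

    # "./llama-7b-int4" stands in for a saved low-bit checkpoint directory.
    # If it contains pytorch_model.bin.index.json, each shard is loaded in
    # turn and freed with gc.collect() before the next one, so peak memory
    # stays near the size of a single shard rather than the full model.
    model = AutoModelForCausalLM.load_low_bit("./llama-7b-int4")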
The remaining hunks are in the companion `.utils` module that the loader imports from.

@@ -48,6 +48,7 @@ from torch import nn
 
 
 WEIGHTS_NAME = "pytorch_model.bin"
+WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 
 
 def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant):
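For orientation, a sharded PyTorch checkpoint directory in the Hugging Face layout keeps numbered shard files next to the index that `WEIGHTS_INDEX_NAME` names (shard count and names below are illustrative only):

    pytorch_model-00001-of-00002.bin
    pytorch_model-00002-of-00002.bin
    pytorch_model.bin.index.json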
@@ -59,7 +60,18 @@ def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant):
         archive_file = os.path.join(
             pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
         )
-        return archive_file
+        return archive_file, False
+    elif os.path.isfile(
+        os.path.join(pretrained_model_name_or_path,
+                     subfolder,
+                     _add_variant(WEIGHTS_INDEX_NAME, variant))
+    ):
+        # Load from a sharded PyTorch checkpoint
+        archive_file = os.path.join(
+            pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+        )
+        is_sharded = True
+        return archive_file, is_sharded
     else:
         invalidInputError(False,
                           f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}"
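Callers now branch on the second element of the returned pair. A small sketch of the two shapes, assuming a hypothetical checkpoint directory `ckpt/` with no subfolder or variant:

    archive_file, is_sharded = extract_local_archive_file("ckpt", "", None)
    # single-file layout -> ("ckpt/pytorch_model.bin", False)
    # sharded layout     -> ("ckpt/pytorch_model.bin.index.json", True)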
@@ -89,3 +101,31 @@ def load(module: nn.Module, state_dict, prefix=""):
     for name, child in module._modules.items():
         if child is not None:
             load(child, state_dict, prefix + name + ".")
+
+
+def get_local_shard_files(pretrained_model_name_or_path, index_filename, subfolder=""):
+    import json
+
+    invalidInputError(os.path.isfile(index_filename),
+                      "Can't find a checkpoint index"
+                      f" ({index_filename}) in {pretrained_model_name_or_path}.")
+
+    with open(index_filename, "r") as f:
+        index = json.loads(f.read())
+
+    shard_filenames = sorted(set(index["weight_map"].values()))
+    sharded_metadata = index["metadata"]
+    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
+    sharded_metadata["weight_map"] = index["weight_map"].copy()
+
+    shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f)
+                       for f in shard_filenames]
+    return shard_filenames, sharded_metadata
+
+
+def fix_key(key):
+    if "beta" in key:
+        return key.replace("beta", "bias")
+    if "gamma" in key:
+        return key.replace("gamma", "weight")
+    return key
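To make the two helpers concrete, here is a sketch of the index file `get_local_shard_files` parses and of `fix_key`'s renaming; the file names, size, and tensor keys are hypothetical:

    # ckpt/pytorch_model.bin.index.json (abridged, made-up content):
    # {
    #   "metadata": {"total_size": 26953670656},
    #   "weight_map": {
    #     "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
    #     "lm_head.weight": "pytorch_model-00002-of-00002.bin"
    #   }
    # }
    shards, meta = get_local_shard_files("ckpt", "ckpt/pytorch_model.bin.index.json")
    # shards -> ["ckpt/pytorch_model-00001-of-00002.bin",
    #            "ckpt/pytorch_model-00002-of-00002.bin"]
    # meta["all_checkpoint_keys"] -> ["model.embed_tokens.weight", "lm_head.weight"]

    # fix_key maps legacy TF-style norm parameter names onto PyTorch ones:
    fix_key("encoder.layer.0.LayerNorm.gamma")  # -> "encoder.layer.0.LayerNorm.weight"
    fix_key("encoder.layer.0.LayerNorm.beta")   # -> "encoder.layer.0.LayerNorm.bias"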