LLM: Support load_low_bit loading models in shards format (#8612)
* shards_model

Co-authored-by: leonardozcm <leonaordo1997zcm@gmail.com>
Parent: 919791e406
Commit: 5b484ab48d
2 changed files with 92 additions and 8 deletions
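For context: a sharded PyTorch checkpoint stores its weights across several numbered .bin files plus a pytorch_model.bin.index.json index that maps every parameter name to the shard containing it. A minimal sketch of such an index once parsed, as a Python dict (shard names, keys, and sizes are illustrative):

# What pytorch_model.bin.index.json contains once parsed (illustrative values):
index = {
    "metadata": {"total_size": 13476839424},
    "weight_map": {
        "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
        "model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
        "lm_head.weight": "pytorch_model-00002-of-00002.bin",
    },
}

# The get_local_shard_files helper added below recovers the shard list from it:
shard_files = sorted(set(index["weight_map"].values()))
# -> ["pytorch_model-00001-of-00002.bin", "pytorch_model-00002-of-00002.bin"]

This is the layout Transformers' save_pretrained produces once a checkpoint exceeds the shard size limit, and it is what load_low_bit could not consume before this commit.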
@@ -14,9 +14,14 @@
 # limitations under the License.
 #
 
+import gc
 import transformers
 from transformers.configuration_utils import PretrainedConfig
-from .utils import extract_local_archive_file, load_state_dict, load
+from .utils import extract_local_archive_file, \
+    load_state_dict, \
+    load, \
+    get_local_shard_files, \
+    fix_key
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 
@@ -147,10 +152,49 @@ class _BaseAutoModelClass:
         # and the tensor shape of int4 weights without quantization.
         model = ggml_convert_quant(model, qtype, convert_shape_only=True)
         # Load the quantized model at last.
-        archive_file = extract_local_archive_file(pretrained_model_name_or_path,
-                                                  subfolder,
-                                                  variant)
-        state_dict = load_state_dict(archive_file)
-        load(model, state_dict)
-        del state_dict
+        resolved_archive_file, is_sharded = extract_local_archive_file(
+            pretrained_model_name_or_path,
+            subfolder,
+            variant)
+        if is_sharded:
+            resolved_archive_file, sharded_metadata = \
+                get_local_shard_files(pretrained_model_name_or_path,
+                                      resolved_archive_file,
+                                      subfolder=subfolder)
+            start_prefix = ""
+            prefix = model.base_model_prefix
+            loaded_keys = [fix_key(key) for key in sharded_metadata["all_checkpoint_keys"]]
+            if len(prefix) > 0:
+                has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
+            else:
+                has_prefix_module = False
+
+            model_cls = type(model)
+            if len(model_cls.base_model_prefix) > 0 and \
+                    not hasattr(model, model_cls.base_model_prefix) and \
+                    has_prefix_module:
+                start_prefix = model_cls.base_model_prefix + "."
+            from transformers.modeling_utils import _load_state_dict_into_model
+            error_msgs = []
+            for shard_file in resolved_archive_file:
+                state_dict = load_state_dict(shard_file)
+                error_msgs += _load_state_dict_into_model(model, state_dict, start_prefix)
+                # force memory release
+                del state_dict
+                gc.collect()
+
+            if len(error_msgs) > 0:
+                error_msg = "\n\t".join(error_msgs)
+                if "size mismatch" in error_msg:
+                    error_msg += (
+                        "\n\tYou may consider adding `ignore_mismatched_sizes=True`"
+                        " in the model `from_pretrained` method."
+                    )
+                invalidInputError(False, "Error(s) in loading state_dict"
+                                  f"for {model.__class__.__name__}:\n\t{error_msg}")
+
+        else:
+            state_dict = load_state_dict(resolved_archive_file)
+            load(model, state_dict)
+            del state_dict
 
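The prefix bookkeeping above mirrors Transformers' own from_pretrained: checkpoint keys may or may not carry the wrapper's base_model_prefix, and start_prefix reconciles the two. A hedged sketch of the case this branch handles, with illustrative BERT names (the commit itself is model-agnostic):

# Checkpoint was saved from a head model, so keys carry the "bert." prefix:
loaded_keys = ["bert.embeddings.word_embeddings.weight", "classifier.weight"]

prefix = "bert"                      # model.base_model_prefix
has_prefix_module = any(k.startswith(prefix) for k in loaded_keys)  # True

# Target is the bare base model: it has no submodule named "bert"
# (it *is* that module), so hasattr(model, "bert") is False and the
# loader must prepend "bert." when looking up each of its parameters:
start_prefix = prefix + "."
param_path = "embeddings.word_embeddings.weight"   # a path inside the model
checkpoint_key = start_prefix + param_path
# -> "bert.embeddings.word_embeddings.weight", which is in loaded_keys

Loading shard by shard and calling del plus gc.collect() after each one keeps peak memory near the model size plus one shard, rather than the model size plus the whole checkpoint. The remaining hunks below are in the .utils helper module that the new imports refer to.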
@@ -48,6 +48,7 @@ from torch import nn
 
 
 WEIGHTS_NAME = "pytorch_model.bin"
+WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 
 
 def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant):
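_add_variant (already used above with WEIGHTS_NAME) splices an optional variant tag into a file name before its last extension, so the new index constant composes with variants too. A hedged re-implementation to show the resulting names ("fp16" is an illustrative variant; the real helper lives in transformers):

def add_variant(weights_name, variant=None):
    # Insert the variant tag before the final extension, as transformers does.
    if variant is not None:
        parts = weights_name.split(".")
        weights_name = ".".join(parts[:-1] + [variant] + parts[-1:])
    return weights_name

print(add_variant("pytorch_model.bin", "fp16"))
# -> pytorch_model.fp16.bin
print(add_variant("pytorch_model.bin.index.json", "fp16"))
# -> pytorch_model.bin.index.fp16.json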
@@ -59,7 +60,18 @@ def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant
         archive_file = os.path.join(
             pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
         )
-        return archive_file
+        return archive_file, False
+    elif os.path.isfile(
+        os.path.join(pretrained_model_name_or_path,
+                     subfolder,
+                     _add_variant(WEIGHTS_INDEX_NAME, variant))
+    ):
+        # Load from a sharded PyTorch checkpoint
+        archive_file = os.path.join(
+            pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+        )
+        is_sharded = True
+        return archive_file, is_sharded
     else:
         invalidInputError(False,
                           f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}"
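extract_local_archive_file now reports whether the checkpoint is sharded instead of returning a bare path. A usage sketch under an assumed local directory layout (directory and shard names hypothetical):

# my-llama-7b/
#   pytorch_model-00001-of-00002.bin
#   pytorch_model-00002-of-00002.bin
#   pytorch_model.bin.index.json
archive_file, is_sharded = extract_local_archive_file("my-llama-7b",
                                                      subfolder="",
                                                      variant=None)
# archive_file -> "my-llama-7b/pytorch_model.bin.index.json"; is_sharded -> True

if is_sharded:
    shard_files, sharded_metadata = get_local_shard_files("my-llama-7b",
                                                          archive_file)
    # shard_files -> the two .bin paths; sharded_metadata carries
    # "all_checkpoint_keys" and a copy of the weight_map

Returning (path, False) in the single-file case keeps the old behaviour reachable through the same call site, which is exactly how the caller in the first file branches.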
@@ -89,3 +101,31 @@ def load(module: nn.Module, state_dict, prefix=""):
     for name, child in module._modules.items():
         if child is not None:
             load(child, state_dict, prefix + name + ".")
+
+
+def get_local_shard_files(pretrained_model_name_or_path, index_filename, subfolder=""):
+    import json
+
+    invalidInputError(os.path.isfile(index_filename),
+                      "Can't find a checkpoint index"
+                      f" ({index_filename}) in {pretrained_model_name_or_path}.")
+
+    with open(index_filename, "r") as f:
+        index = json.loads(f.read())
+
+    shard_filenames = sorted(set(index["weight_map"].values()))
+    sharded_metadata = index["metadata"]
+    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
+    sharded_metadata["weight_map"] = index["weight_map"].copy()
+
+    shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f)
+                       for f in shard_filenames]
+    return shard_filenames, sharded_metadata
+
+
+def fix_key(key):
+    if "beta" in key:
+        return key.replace("beta", "bias")
+    if "gamma" in key:
+        return key.replace("gamma", "weight")
+    return key
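fix_key reproduces the renaming Transformers applies to checkpoints ported from TensorFlow, where LayerNorm parameters were historically saved as gamma/beta instead of weight/bias. A quick demonstration:

print(fix_key("encoder.layer.0.LayerNorm.gamma"))  # -> encoder.layer.0.LayerNorm.weight
print(fix_key("encoder.layer.0.LayerNorm.beta"))   # -> encoder.layer.0.LayerNorm.bias
print(fix_key("lm_head.weight"))                   # unchanged

Note that the substring match is deliberately loose, matching the upstream helper: any key containing "gamma" or "beta" anywhere is rewritten, not just LayerNorm parameters.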