LLM: Enable automatic fallback for the model loading method (#8841)

* Enable the automatic loading-method fallback

* Emit a warning on the load error

* Use logger.info for the fallback message

---------

Co-authored-by: leonardozcm <leonardozcm@gmail.com>
Author: Zhao Changmin
Committed: 2023-08-30 15:41:55 +08:00 (via GitHub)
Commit: 731916c639
Parent: bba73ec9d2
2 changed files with 29 additions and 5 deletions


@@ -85,12 +85,13 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
-            _, has_been_replaced = _replace_with_quant_linear(
+            _, _flag = _replace_with_quant_linear(
                 module,
                 qtype,
                 modules_to_not_convert,
                 current_key_name,
             )
+            has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
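
The first hunk fixes a subtle bug: the recursive call used to overwrite `has_been_replaced`, so a successful replacement recorded for an earlier child could be erased by a later child that replaced nothing. Accumulating with `or` preserves it. A minimal sketch of the pattern, using a toy tree type rather than the real PyTorch/BigDL modules:

    # Toy illustration of the flag-accumulation fix; Node is a stand-in
    # for a module tree, not the real BigDL/PyTorch types.
    class Node:
        def __init__(self, is_linear=False, children=None):
            self.is_linear = is_linear
            self.children = children or []

    def replace(node):
        has_been_replaced = False
        for child in node.children:
            if child.is_linear:
                child.is_linear = False      # pretend we swapped in a quant layer
                has_been_replaced = True
            # Buggy form: `has_been_replaced = replace(child)` would erase
            # a True recorded for an earlier sibling.
            _flag = replace(child)
            has_been_replaced = _flag or has_been_replaced
        return has_been_replaced

    root = Node(children=[Node(is_linear=True), Node()])
    print(replace(root))  # True: the flag survives the second (no-op) child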


@@ -22,6 +22,11 @@ from .utils import extract_local_archive_file, \
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 import torch
+import copy
+import logging
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 
 def save_low_bit(self, *args, **kwargs):
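
The new module-level logging setup timestamps every record. For reference, a quick sketch of what the fallback notice looks like with that format string (the output line is illustrative):

    import logging

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        level=logging.INFO)
    logger = logging.getLogger(__name__)

    logger.info("Falling back to the traditional load method.")
    # e.g.: 2023-08-30 15:41:55,123 - INFO - Falling back to the traditional load method.
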
@@ -98,7 +103,19 @@ class _BaseAutoModelClass:
                               f"Unknown load_in_low_bit value: {q_k}, expected:"
                               f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
             qtype = ggml_tensor_qtype[q_k]
-            model = cls.HF_Model.from_pretrained(*args, **kwargs)
+            # In case it needs a second try,
+            # `from_pretrained` may pop items out of the dict
+            # and lead to missing args.
+            _args = copy.deepcopy(args)
+            _kwargs = copy.deepcopy(kwargs)
+            try:
+                model = cls.HF_Model.from_pretrained(*args, **kwargs)
+            except NotImplementedError:
+                logger.info("Failed to load models with `low_cpu_mem_usage` specified, "
+                            "will fall back to the traditional loading method with higher memory consumption.")
+                _kwargs["low_cpu_mem_usage"] = False
+                model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
+                model.config.update({"bigdl_lcmu_enabled": False})
             model = model.to("cpu")
             model = ggml_convert_quant(model, qtype, optimize_model)
             model.config.update({"bigdl_transformers_low_bit": q_k})
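
The deep copies matter because `from_pretrained` can consume ("pop") entries from the arguments it is handed, so a failed first attempt would leave the retry with missing items. A toy sketch of why the retry needs pristine copies; `load` here is a stand-in, not the Hugging Face API:

    import copy

    # Stand-in loader: it pops entries from the dict it is given, so a
    # failed attempt leaves the caller's dict incomplete.
    def load(config):
        name = config.pop("name")                  # consumed even if we fail below
        if config.pop("low_cpu_mem_usage", False):
            raise NotImplementedError("meta-device loading unsupported here")
        return name

    kwargs = {"name": "my-model", "low_cpu_mem_usage": True}
    _kwargs = copy.deepcopy(kwargs)                # pristine copy for the retry

    try:
        model = load(kwargs)
    except NotImplementedError:
        _kwargs["low_cpu_mem_usage"] = False
        model = load(_kwargs)                      # retry still sees every entry

    print(model)  # my-model; retrying with `kwargs` would raise KeyError: 'name'
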
@@ -151,6 +168,7 @@ class _BaseAutoModelClass:
         config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
         bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
+        bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True)
 
         invalidInputError(bigdl_transformers_low_bit,
                           "Detect this model is not a low-bit model, Please use from_pretrained"
@@ -226,10 +244,15 @@ class _BaseAutoModelClass:
         init_contexts = [no_init_weights(_enable=_fast_init)]
         init_contexts.append(init_empty_weights())
 
-        with ContextManagers(init_contexts):
+        if bigdl_lcmu_enabled:
+            with ContextManagers(init_contexts):
+                model = model_class(config, *model_args, **kwargs)
+        else:
             model = model_class(config, *model_args, **kwargs)
-        model = ggml_convert_quant(model, qtype, optimize_model, device="meta")
+        # Loading args may differ based on their usage
+        quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
+        model = ggml_convert_quant(model, qtype, optimize_model, device=quant_device)
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
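
The branch exists because `init_empty_weights()` (from `accelerate`) creates shape-only parameters on the meta device, so the subsequent quantization must target meta as well; when the fallback disables it, weights are ordinary CPU tensors. A short sketch of the difference (requires torch and accelerate):

    import torch
    from accelerate import init_empty_weights

    # Inside the context manager, parameters are shape-only meta tensors.
    with init_empty_weights():
        lazy = torch.nn.Linear(4, 4)
    print(lazy.weight.device)   # meta

    # Without it, parameters are real, allocated CPU tensors.
    eager = torch.nn.Linear(4, 4)
    print(eager.weight.device)  # cpu
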
@@ -260,7 +283,7 @@ class _BaseAutoModelClass:
             pretrained_model_name_or_path,
             sharded_metadata=sharded_metadata,
             _fast_init=_fast_init,
-            low_cpu_mem_usage=True,
+            low_cpu_mem_usage=bigdl_lcmu_enabled,
             offload_folder=offload_folder,
             offload_state_dict=offload_state_dict,
             dtype=torch_dtype,
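
Taken together, caller code needs no changes: the same call now retries transparently when meta-device loading raises NotImplementedError, and the saved config remembers which path was used. A usage sketch against the bigdl-llm API this commit touches (the checkpoint name is illustrative):

    from bigdl.llm.transformers import AutoModelForCausalLM

    # If the underlying Hugging Face loader rejects `low_cpu_mem_usage`,
    # bigdl-llm now retries with low_cpu_mem_usage=False and records
    # bigdl_lcmu_enabled=False in the model config.
    model = AutoModelForCausalLM.from_pretrained(
        "decapoda-research/llama-7b-hf",   # illustrative model id
        load_in_low_bit="sym_int4",        # one of the qtypes validated above
    )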