LLM: Enable attempting loading method automatically (#8841)
* enable auto load method * warning error * logger info --------- Co-authored-by: leonardozcm <leonardozcm@gmail.com>
This commit is contained in:
parent
bba73ec9d2
commit
731916c639
2 changed files with 29 additions and 5 deletions
|
|
@ -85,12 +85,13 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
|
|||
|
||||
# Remove the last key for recursion
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _replace_with_quant_linear(
|
||||
_, _flag = _replace_with_quant_linear(
|
||||
module,
|
||||
qtype,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
)
|
||||
has_been_replaced = _flag or has_been_replaced
|
||||
return model, has_been_replaced
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,11 @@ from .utils import extract_local_archive_file, \
|
|||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
import torch
|
||||
import copy
|
||||
import logging
|
||||
|
||||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def save_low_bit(self, *args, **kwargs):
|
||||
|
|
@ -98,7 +103,19 @@ class _BaseAutoModelClass:
|
|||
f"Unknown load_in_low_bit value: {q_k}, expected:"
|
||||
f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
|
||||
qtype = ggml_tensor_qtype[q_k]
|
||||
model = cls.HF_Model.from_pretrained(*args, **kwargs)
|
||||
# In case it needs a second try,
|
||||
# `from_pretrained`` may pop items out in dict
|
||||
# and lead to args missing.
|
||||
_args = copy.deepcopy(args)
|
||||
_kwargs = copy.deepcopy(kwargs)
|
||||
try:
|
||||
model = cls.HF_Model.from_pretrained(*args, **kwargs)
|
||||
except NotImplementedError:
|
||||
logger.info("Failed to load models with `low_cpu_mem_usage` specified, "
|
||||
"will fall to traditional load method with higher memory consumption.")
|
||||
_kwargs["low_cpu_mem_usage"] = False
|
||||
model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
|
||||
model.config.update({"bigdl_lcmu_enabled": False})
|
||||
model = model.to("cpu")
|
||||
model = ggml_convert_quant(model, qtype, optimize_model)
|
||||
model.config.update({"bigdl_transformers_low_bit": q_k})
|
||||
|
|
@ -151,6 +168,7 @@ class _BaseAutoModelClass:
|
|||
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
|
||||
bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
|
||||
bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True)
|
||||
|
||||
invalidInputError(bigdl_transformers_low_bit,
|
||||
"Detect this model is not a low-bit model, Please use from_pretrained"
|
||||
|
|
@ -226,10 +244,15 @@ class _BaseAutoModelClass:
|
|||
init_contexts = [no_init_weights(_enable=_fast_init)]
|
||||
init_contexts.append(init_empty_weights())
|
||||
|
||||
with ContextManagers(init_contexts):
|
||||
if bigdl_lcmu_enabled:
|
||||
with ContextManagers(init_contexts):
|
||||
model = model_class(config, *model_args, **kwargs)
|
||||
else:
|
||||
model = model_class(config, *model_args, **kwargs)
|
||||
|
||||
model = ggml_convert_quant(model, qtype, optimize_model, device="meta")
|
||||
# Loading args may differ based on their usage
|
||||
quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
|
||||
model = ggml_convert_quant(model, qtype, optimize_model, device=quant_device)
|
||||
|
||||
if is_sharded:
|
||||
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
||||
|
|
@ -260,7 +283,7 @@ class _BaseAutoModelClass:
|
|||
pretrained_model_name_or_path,
|
||||
sharded_metadata=sharded_metadata,
|
||||
_fast_init=_fast_init,
|
||||
low_cpu_mem_usage=True,
|
||||
low_cpu_mem_usage=bigdl_lcmu_enabled,
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
|
|
|
|||
Loading…
Reference in a new issue