From 731916c639ca08b258a93a845886c28b06de3811 Mon Sep 17 00:00:00 2001 From: Zhao Changmin Date: Wed, 30 Aug 2023 15:41:55 +0800 Subject: [PATCH] LLM: Enable attempting loading method automatically (#8841) * enable auto load method * warning error * logger info --------- Co-authored-by: leonardozcm --- .../llm/src/bigdl/llm/transformers/convert.py | 3 +- .../llm/src/bigdl/llm/transformers/model.py | 31 ++++++++++++++++--- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 72624549..c70b7440 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -85,12 +85,13 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None, # Remove the last key for recursion if len(list(module.children())) > 0: - _, has_been_replaced = _replace_with_quant_linear( + _, _flag = _replace_with_quant_linear( module, qtype, modules_to_not_convert, current_key_name, ) + has_been_replaced = _flag or has_been_replaced return model, has_been_replaced diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index 4d6231ce..75bf196e 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -22,6 +22,11 @@ from .utils import extract_local_archive_file, \ from bigdl.llm.ggml.quantize import ggml_tensor_qtype from bigdl.llm.utils.common import invalidInputError import torch +import copy +import logging + +logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) +logger = logging.getLogger(__name__) def save_low_bit(self, *args, **kwargs): @@ -98,7 +103,19 @@ class _BaseAutoModelClass: f"Unknown load_in_low_bit value: {q_k}, expected:" f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.") qtype = ggml_tensor_qtype[q_k] - model = 
cls.HF_Model.from_pretrained(*args, **kwargs) + # In case it needs a second try, + # `from_pretrained` may pop items out of the dict + # and lead to args missing. + _args = copy.deepcopy(args) + _kwargs = copy.deepcopy(kwargs) + try: + model = cls.HF_Model.from_pretrained(*args, **kwargs) + except NotImplementedError: + logger.info("Failed to load models with `low_cpu_mem_usage` specified, " + "will fall to traditional load method with higher memory consumption.") + _kwargs["low_cpu_mem_usage"] = False + model = cls.HF_Model.from_pretrained(*_args, **_kwargs) + model.config.update({"bigdl_lcmu_enabled": False}) model = model.to("cpu") model = ggml_convert_quant(model, qtype, optimize_model) model.config.update({"bigdl_transformers_low_bit": q_k}) @@ -151,6 +168,7 @@ class _BaseAutoModelClass: config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path) bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False) + bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True) invalidInputError(bigdl_transformers_low_bit, "Detect this model is not a low-bit model, Please use from_pretrained" @@ -226,10 +244,15 @@ class _BaseAutoModelClass: init_contexts = [no_init_weights(_enable=_fast_init)] init_contexts.append(init_empty_weights()) - with ContextManagers(init_contexts): + if bigdl_lcmu_enabled: + with ContextManagers(init_contexts): + model = model_class(config, *model_args, **kwargs) + else: model = model_class(config, *model_args, **kwargs) - model = ggml_convert_quant(model, qtype, optimize_model, device="meta") + # Loading args may differ based on their usage + quant_device = "meta" if bigdl_lcmu_enabled else "cpu" + model = ggml_convert_quant(model, qtype, optimize_model, device=quant_device) if is_sharded: loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] @@ -260,7 +283,7 @@ class _BaseAutoModelClass: pretrained_model_name_or_path, sharded_metadata=sharded_metadata, _fast_init=_fast_init, - 
low_cpu_mem_usage=True, + low_cpu_mem_usage=bigdl_lcmu_enabled, offload_folder=offload_folder, offload_state_dict=offload_state_dict, dtype=torch_dtype,