diff --git a/python/llm/src/bigdl/llm/optimize.py b/python/llm/src/bigdl/llm/optimize.py index 43abef5f..dfb0c344 100644 --- a/python/llm/src/bigdl/llm/optimize.py +++ b/python/llm/src/bigdl/llm/optimize.py @@ -18,6 +18,10 @@ import torch import os import json from .transformers import ggml_convert_low_bit +from torch.nn.modules import Module +from torch.nn.modules.module import _IncompatibleKeys +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device from bigdl.llm.ggml.quantize import ggml_tensor_qtype from bigdl.llm.utils.common import invalidInputError @@ -38,12 +42,27 @@ def _save_low_bit(self, save_dir, *args, **kwargs): json.dump(self._bigdl_config, json_file) -def load_low_bit(model, model_path): - invalidInputError(isinstance(model, torch.nn.Module), - "model should be a instance of `torch.nn.Module`.") +# Under `init_empty_weights()`, we need to disable all actions +# that may lead to any parameter allocation", otherwise may need to error: +# NotImplementedError: Cannot copy out of meta tensor; no data! +class DisableTorchAllocTensor(): + def __init__(self) -> None: + self._old_torch_load_state_dict = Module.load_state_dict + self._old_torch_to_device = Module.to + + def __enter__(self): + Module.load_state_dict = lambda *args, **kwargs: _IncompatibleKeys([], []) + Module.to = lambda self, *args, **kwargs: self + + def __exit__(self, exc_type, exc_value, traceback): + Module.load_state_dict = self._old_torch_load_state_dict + Module.to = self._old_torch_to_device + + +def low_bit_sanity_check(model_path): invalidInputError(os.path.isdir(model_path), "model_path should be a valid directory path.") - invalidInputError(os.path.isdir(os.path.join(model_path, CONFIG_NAME)), + invalidInputError(os.path.isfile(os.path.join(model_path, CONFIG_NAME)), "bigdl_config.json should be under your model directory," "please check your input path.") with open(os.path.join(model_path, CONFIG_NAME), 'r') as f: @@ -54,13 +73,34 @@ def load_low_bit(model, model_path): "Detect this model is not a low-bit model, Please use `optimize_model`" " with low_bit to get a low-bit model , and " " serialize the model using save_low_bit first.") + return low_bit + + +def load_low_bit(model_or_creator, model_path, **kwargs): + is_creator = not isinstance(model_or_creator, torch.nn.Module) \ + and callable(model_or_creator) + low_bit = low_bit_sanity_check(model_path) if low_bit: + # a creator + if is_creator: + with init_empty_weights(), DisableTorchAllocTensor(): + model = model_or_creator(**kwargs) + else: + model = model_or_creator + invalidInputError(isinstance(model, torch.nn.Module), + "model_or_creator should be a instance of " + "`torch.nn.Module`or a method that returns " + f"an instance of `torch.nn.Module`, but got {type(model)} at last.") qtype = ggml_tensor_qtype[low_bit] model = ggml_convert_low_bit(model, qtype=qtype, convert_shape_only=True) state_dict = torch.load(os.path.join(model_path, PYTORCH_MODEL_NAME)) - model.load_state_dict(state_dict=state_dict) + if is_creator: + for param_name, param in state_dict.items(): + set_module_tensor_to_device(model, param_name, "cpu", param) + else: + model.load_state_dict(state_dict=state_dict) return model