LLM: Apply low_cpu_mem_usage algorithm on optimize_model API (#8987)

* low_cpu_mem_usage
This commit is contained in:
Zhao Changmin 2023-09-18 21:41:42 +08:00 committed by GitHub
parent 8299b68fea
commit 2a05581da7

View file

@ -18,6 +18,10 @@ import torch
import os
import json
from .transformers import ggml_convert_low_bit
from torch.nn.modules import Module
from torch.nn.modules.module import _IncompatibleKeys
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
@ -38,12 +42,27 @@ def _save_low_bit(self, save_dir, *args, **kwargs):
json.dump(self._bigdl_config, json_file)
def load_low_bit(model, model_path):
invalidInputError(isinstance(model, torch.nn.Module),
"model should be a instance of `torch.nn.Module`.")
# Under `init_empty_weights()`, we need to disable all actions
# that may lead to any parameter allocation", otherwise may need to error:
# NotImplementedError: Cannot copy out of meta tensor; no data!
class DisableTorchAllocTensor():
def __init__(self) -> None:
self._old_torch_load_state_dict = Module.load_state_dict
self._old_torch_to_device = Module.to
def __enter__(self):
Module.load_state_dict = lambda *args, **kwargs: _IncompatibleKeys([], [])
Module.to = lambda self, *args, **kwargs: self
def __exit__(self, exc_type, exc_value, traceback):
Module.load_state_dict = self._old_torch_load_state_dict
Module.to = self._old_torch_to_device
def low_bit_sanity_check(model_path):
invalidInputError(os.path.isdir(model_path),
"model_path should be a valid directory path.")
invalidInputError(os.path.isdir(os.path.join(model_path, CONFIG_NAME)),
invalidInputError(os.path.isfile(os.path.join(model_path, CONFIG_NAME)),
"bigdl_config.json should be under your model directory,"
"please check your input path.")
with open(os.path.join(model_path, CONFIG_NAME), 'r') as f:
@ -54,12 +73,33 @@ def load_low_bit(model, model_path):
"Detect this model is not a low-bit model, Please use `optimize_model`"
" with low_bit to get a low-bit model , and "
" serialize the model using save_low_bit first.")
return low_bit
def load_low_bit(model_or_creator, model_path, **kwargs):
is_creator = not isinstance(model_or_creator, torch.nn.Module) \
and callable(model_or_creator)
low_bit = low_bit_sanity_check(model_path)
if low_bit:
# a creator
if is_creator:
with init_empty_weights(), DisableTorchAllocTensor():
model = model_or_creator(**kwargs)
else:
model = model_or_creator
invalidInputError(isinstance(model, torch.nn.Module),
"model_or_creator should be a instance of "
"`torch.nn.Module`or a method that returns "
f"an instance of `torch.nn.Module`, but got {type(model)} at last.")
qtype = ggml_tensor_qtype[low_bit]
model = ggml_convert_low_bit(model, qtype=qtype, convert_shape_only=True)
state_dict = torch.load(os.path.join(model_path, PYTORCH_MODEL_NAME))
if is_creator:
for param_name, param in state_dict.items():
set_module_tensor_to_device(model, param_name, "cpu", param)
else:
model.load_state_dict(state_dict=state_dict)
return model