LLM: Apply low_cpu_mem_usage algorithm on optimize_model API (#8987)
* low_cpu_mem_usage
This commit is contained in:
parent
8299b68fea
commit
2a05581da7
1 changed files with 45 additions and 5 deletions
|
|
@ -18,6 +18,10 @@ import torch
|
|||
import os
|
||||
import json
|
||||
from .transformers import ggml_convert_low_bit
|
||||
from torch.nn.modules import Module
|
||||
from torch.nn.modules.module import _IncompatibleKeys
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
||||
|
|
@ -38,12 +42,27 @@ def _save_low_bit(self, save_dir, *args, **kwargs):
|
|||
json.dump(self._bigdl_config, json_file)
|
||||
|
||||
|
||||
def load_low_bit(model, model_path):
|
||||
invalidInputError(isinstance(model, torch.nn.Module),
|
||||
"model should be a instance of `torch.nn.Module`.")
|
||||
# Under `init_empty_weights()`, we need to disable all actions
|
||||
# that may lead to any parameter allocation", otherwise may need to error:
|
||||
# NotImplementedError: Cannot copy out of meta tensor; no data!
|
||||
class DisableTorchAllocTensor():
|
||||
def __init__(self) -> None:
|
||||
self._old_torch_load_state_dict = Module.load_state_dict
|
||||
self._old_torch_to_device = Module.to
|
||||
|
||||
def __enter__(self):
|
||||
Module.load_state_dict = lambda *args, **kwargs: _IncompatibleKeys([], [])
|
||||
Module.to = lambda self, *args, **kwargs: self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
Module.load_state_dict = self._old_torch_load_state_dict
|
||||
Module.to = self._old_torch_to_device
|
||||
|
||||
|
||||
def low_bit_sanity_check(model_path):
|
||||
invalidInputError(os.path.isdir(model_path),
|
||||
"model_path should be a valid directory path.")
|
||||
invalidInputError(os.path.isdir(os.path.join(model_path, CONFIG_NAME)),
|
||||
invalidInputError(os.path.isfile(os.path.join(model_path, CONFIG_NAME)),
|
||||
"bigdl_config.json should be under your model directory,"
|
||||
"please check your input path.")
|
||||
with open(os.path.join(model_path, CONFIG_NAME), 'r') as f:
|
||||
|
|
@ -54,13 +73,34 @@ def load_low_bit(model, model_path):
|
|||
"Detect this model is not a low-bit model, Please use `optimize_model`"
|
||||
" with low_bit to get a low-bit model , and "
|
||||
" serialize the model using save_low_bit first.")
|
||||
return low_bit
|
||||
|
||||
|
||||
def load_low_bit(model_or_creator, model_path, **kwargs):
|
||||
is_creator = not isinstance(model_or_creator, torch.nn.Module) \
|
||||
and callable(model_or_creator)
|
||||
low_bit = low_bit_sanity_check(model_path)
|
||||
|
||||
if low_bit:
|
||||
# a creator
|
||||
if is_creator:
|
||||
with init_empty_weights(), DisableTorchAllocTensor():
|
||||
model = model_or_creator(**kwargs)
|
||||
else:
|
||||
model = model_or_creator
|
||||
invalidInputError(isinstance(model, torch.nn.Module),
|
||||
"model_or_creator should be a instance of "
|
||||
"`torch.nn.Module`or a method that returns "
|
||||
f"an instance of `torch.nn.Module`, but got {type(model)} at last.")
|
||||
qtype = ggml_tensor_qtype[low_bit]
|
||||
model = ggml_convert_low_bit(model, qtype=qtype, convert_shape_only=True)
|
||||
|
||||
state_dict = torch.load(os.path.join(model_path, PYTORCH_MODEL_NAME))
|
||||
model.load_state_dict(state_dict=state_dict)
|
||||
if is_creator:
|
||||
for param_name, param in state_dict.items():
|
||||
set_module_tensor_to_device(model, param_name, "cpu", param)
|
||||
else:
|
||||
model.load_state_dict(state_dict=state_dict)
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue