LLM: Apply low_cpu_mem_usage algorithm on optimize_model API (#8987)
* low_cpu_mem_usage
parent 8299b68fea
commit 2a05581da7

1 changed file with 45 additions and 5 deletions
@@ -18,6 +18,10 @@ import torch
 import os
 import json
 from .transformers import ggml_convert_low_bit
+from torch.nn.modules import Module
+from torch.nn.modules.module import _IncompatibleKeys
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 
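The four new imports carry the low_cpu_mem_usage mechanics: accelerate's `init_empty_weights` instantiates modules on the meta device (parameter shapes only, no storage), and `set_module_tensor_to_device` later materializes individual tensors from loaded data. A minimal sketch of that pattern, independent of this diff (the `Linear` layer is only an illustration):

    import torch
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    # Build the module on the meta device: shapes exist, but no memory is allocated.
    with init_empty_weights():
        layer = torch.nn.Linear(4096, 4096)
    assert layer.weight.device.type == "meta"

    # Materialize one named tensor on CPU from real data.
    set_module_tensor_to_device(layer, "weight", "cpu", torch.zeros(4096, 4096))
    assert layer.weight.device.type == "cpu"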
@@ -38,12 +42,27 @@ def _save_low_bit(self, save_dir, *args, **kwargs):
         json.dump(self._bigdl_config, json_file)
 
 
-def load_low_bit(model, model_path):
-    invalidInputError(isinstance(model, torch.nn.Module),
-                      "model should be a instance of `torch.nn.Module`.")
+# Under `init_empty_weights()`, we need to disable all actions
+# that may lead to any parameter allocation, otherwise it may raise:
+# NotImplementedError: Cannot copy out of meta tensor; no data!
+class DisableTorchAllocTensor():
+    def __init__(self) -> None:
+        self._old_torch_load_state_dict = Module.load_state_dict
+        self._old_torch_to_device = Module.to
+
+    def __enter__(self):
+        Module.load_state_dict = lambda *args, **kwargs: _IncompatibleKeys([], [])
+        Module.to = lambda self, *args, **kwargs: self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        Module.load_state_dict = self._old_torch_load_state_dict
+        Module.to = self._old_torch_to_device
+
+
+def low_bit_sanity_check(model_path):
     invalidInputError(os.path.isdir(model_path),
                       "model_path should be a valid directory path.")
-    invalidInputError(os.path.isdir(os.path.join(model_path, CONFIG_NAME)),
+    invalidInputError(os.path.isfile(os.path.join(model_path, CONFIG_NAME)),
                       "bigdl_config.json should be under your model directory,"
                       "please check your input path.")
     with open(os.path.join(model_path, CONFIG_NAME), 'r') as f:
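`DisableTorchAllocTensor` is needed because a model constructor run under `init_empty_weights()` may itself call `Module.to` or `Module.load_state_dict`, and both fail on meta tensors with exactly the `NotImplementedError` quoted in the comment above. A small repro of that failure mode (illustrative, not part of the diff):

    import torch
    from accelerate import init_empty_weights

    with init_empty_weights():
        m = torch.nn.Linear(8, 8)  # parameters live on the meta device

    try:
        m.to("cpu")  # tries to copy tensor data that does not exist
    except NotImplementedError as e:
        print(e)  # Cannot copy out of meta tensor; no data!

Inside the context manager both methods become no-ops (`load_state_dict` returns an empty `_IncompatibleKeys`, so callers that inspect the result still work), and `__exit__` restores the originals, which scopes the monkey-patch to the creator call.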
@@ -54,12 +73,33 @@ def load_low_bit(model, model_path):
                       "Detect this model is not a low-bit model, Please use `optimize_model`"
                       " with low_bit to get a low-bit model , and "
                       " serialize the model using save_low_bit first.")
+    return low_bit
+
+
+def load_low_bit(model_or_creator, model_path, **kwargs):
+    is_creator = not isinstance(model_or_creator, torch.nn.Module) \
+        and callable(model_or_creator)
+    low_bit = low_bit_sanity_check(model_path)
     if low_bit:
+        # a creator
+        if is_creator:
+            with init_empty_weights(), DisableTorchAllocTensor():
+                model = model_or_creator(**kwargs)
+        else:
+            model = model_or_creator
+        invalidInputError(isinstance(model, torch.nn.Module),
+                          "model_or_creator should be an instance of "
+                          "`torch.nn.Module` or a method that returns "
+                          f"an instance of `torch.nn.Module`, but got {type(model)} instead.")
         qtype = ggml_tensor_qtype[low_bit]
         model = ggml_convert_low_bit(model, qtype=qtype, convert_shape_only=True)
 
         state_dict = torch.load(os.path.join(model_path, PYTORCH_MODEL_NAME))
-        model.load_state_dict(state_dict=state_dict)
+        if is_creator:
+            for param_name, param in state_dict.items():
+                set_module_tensor_to_device(model, param_name, "cpu", param)
+        else:
+            model.load_state_dict(state_dict=state_dict)
     return model
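Putting the pieces together, the low_cpu_mem_usage path is: instantiate the model shell on the meta device, convert it to the target low-bit layout with `convert_shape_only=True` (no tensor data is copied), then write each saved tensor into place with `set_module_tensor_to_device`, so the model never has to be fully materialized in its original precision. A hedged usage sketch; the import path is assumed from the package layout, and the checkpoint directory and creator below are hypothetical:

    from transformers import AutoModelForCausalLM
    from bigdl.llm.optimize import load_low_bit

    def create_model():
        # Any callable returning a torch.nn.Module can serve as the creator.
        return AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

    # Low-memory path: the creator runs under init_empty_weights() and
    # DisableTorchAllocTensor, so only the low-bit tensors are allocated.
    model = load_low_bit(create_model, "./llama-7b-low-bit")

    # The previous calling convention still works with a prebuilt module.
    model = load_low_bit(create_model(), "./llama-7b-low-bit")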