LLM: Adapt transformers models for optimize_model save/load (#9022)

* LLM: Adapt transformers models for save/load
Zhao Changmin 2023-10-09 11:13:44 +08:00 committed by GitHub
parent f64257a093
commit 548e4dd5fe
4 changed files with 286 additions and 20 deletions


@@ -24,6 +24,12 @@ from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.utils import extract_local_archive_file, get_local_shard_files
import transformers
from transformers import PreTrainedModel
from .utils.common import MuteHFLogger
from .utils.lazy_load_torch import LazyLoadTensors
from contextlib import ExitStack, contextmanager
# Simulate the Hugging Face format
@@ -37,7 +43,14 @@ def _save_low_bit(self, save_dir, *args, **kwargs):
f" load_in_4bit or load_in_low_bit parameter to load a 4-bit model first.")
os.makedirs(save_dir, exist_ok=True)
model_path = os.path.join(save_dir, PYTORCH_MODEL_NAME)
torch.save(self.state_dict(), model_path, *args, **kwargs)
if isinstance(self, PreTrainedModel):
# We reuse `save_pretrained` here to cover transformers model cases
# as much as possible; the two code paths may be merged later.
self.save_pretrained(save_dir)
else:
# TODO: For the lowbit model still larger than 8GB,
# save it into shards.
torch.save(self.state_dict(), model_path, *args, **kwargs)
with open(os.path.join(save_dir, CONFIG_NAME), "w") as json_file:
json.dump(self._bigdl_config, json_file)
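# Illustrative usage of the patched save path (a sketch, not part of the
# commit; the directory name is hypothetical). `optimize_model` attaches
# `save_low_bit` to the model:
#     model = optimize_model(model)
#     model.save_low_bit("./llama-low-bit")  # save_pretrained for PreTrainedModel,
#                                            # a single torch.save otherwise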
@@ -49,14 +62,44 @@ class DisableTorchAllocTensor():
def __init__(self) -> None:
self._old_torch_load_state_dict = Module.load_state_dict
self._old_torch_to_device = Module.to
self._old_torch_load_from_state_dict = Module._load_from_state_dict
# ChatGLM2 initializes its weights manually,
# and `skip_init` initializes on `cpu` by default
self._old_skip_init = torch.nn.utils.skip_init
def __enter__(self):
Module.load_state_dict = lambda *args, **kwargs: _IncompatibleKeys([], [])
Module._load_from_state_dict = lambda *args, **kwargs: None
Module.to = lambda self, *args, **kwargs: self
def skip_init_on_meta(module_cls, *args, **kwargs):
kwargs['device'] = 'meta'
return self._old_skip_init(module_cls, *args, **kwargs)
torch.nn.utils.skip_init = skip_init_on_meta
def __exit__(self, exc_type, exc_value, traceback):
Module.load_state_dict = self._old_torch_load_state_dict
Module._load_from_state_dict = self._old_torch_load_from_state_dict
Module.to = self._old_torch_to_device
torch.nn.utils.skip_init = self._old_skip_init
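# A minimal sketch of this context manager's effect (illustrative only):
# inside it, `Module.to` is a no-op, `load_state_dict` reports no missing or
# unexpected keys, and `skip_init` builds modules on the meta device:
#     with DisableTorchAllocTensor():
#         lin = torch.nn.utils.skip_init(torch.nn.Linear, 2, 2)
#     assert lin.weight.is_meta  # no real storage was allocated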
class ContextManagers:
"""
Wrapper for `contextlib.ExitStack` which enters a collection of context managers.
Adaptation of `ContextManagers` in the `fastcore` library.
"""
def __init__(self, context_managers):
self.context_managers = context_managers
self.stack = ExitStack()
def __enter__(self):
for context_manager in self.context_managers:
self.stack.enter_context(context_manager)
def __exit__(self, *args, **kwargs):
self.stack.__exit__(*args, **kwargs)
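# Usage sketch: `with ContextManagers([cm_a, cm_b]): ...` enters cm_a then
# cm_b, and the ExitStack unwinds both in reverse order on exit.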
def low_bit_sanity_check(model_path):
@@ -76,31 +119,43 @@ def low_bit_sanity_check(model_path):
return low_bit
def load_low_bit(model_or_creator, model_path, **kwargs):
is_creator = not isinstance(model_or_creator, torch.nn.Module) \
and callable(model_or_creator)
low_bit = low_bit_sanity_check(model_path)
@contextmanager
def low_memory_init():
init_contexts = []
init_contexts.extend([init_empty_weights(), DisableTorchAllocTensor()])
# Load everything except the tensors' actual data
init_contexts.append(LazyLoadTensors())
# Because `torch.load` is patched above, HF would emit missing-key warnings;
# they do not matter since the weights are loaded again later, so mute the logger.
init_contexts.append(MuteHFLogger(logger=transformers.modeling_utils.logger))
with ContextManagers(init_contexts):
yield
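# Typical round trip (this mirrors the test added below; paths are illustrative):
#     with low_memory_init():
#         model = AutoModelForCausalLM.from_pretrained(save_dir, torch_dtype="auto")
#     model = load_low_bit(model, save_dir)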
def load_low_bit(model, model_path):
low_bit = low_bit_sanity_check(model_path)
invalidInputError(isinstance(model, torch.nn.Module),
"model should be an instance of "
f"`torch.nn.Module`, but got {type(model)}.")
if low_bit:
# a creator
if is_creator:
with init_empty_weights(), DisableTorchAllocTensor():
model = model_or_creator(**kwargs)
else:
model = model_or_creator
invalidInputError(isinstance(model, torch.nn.Module),
"model_or_creator should be a instance of "
"`torch.nn.Module`or a method that returns "
f"an instance of `torch.nn.Module`, but got {type(model)} at last.")
qtype = ggml_tensor_qtype[low_bit]
model = ggml_convert_low_bit(model, qtype=qtype, convert_shape_only=True)
state_dict = torch.load(os.path.join(model_path, PYTORCH_MODEL_NAME))
if is_creator:
resolved_archive_file, is_sharded = extract_local_archive_file(model_path, subfolder="")
if is_sharded:
# For now, only sharded transformers models
# can take this branch.
resolved_archive_file, _ = \
get_local_shard_files(model_path,
resolved_archive_file,
subfolder="")
else:
resolved_archive_file = [os.path.join(model_path, PYTORCH_MODEL_NAME)]
for model_file in resolved_archive_file:
state_dict = torch.load(model_file)
for param_name, param in state_dict.items():
set_module_tensor_to_device(model, param_name, "cpu", param)
else:
model.load_state_dict(state_dict=state_dict)
return model
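For reference, a standalone sketch of the meta-device materialization pattern that load_low_bit relies on, using only the accelerate helpers imported above (module and tensor values here are illustrative, not from the commit):

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

with init_empty_weights():
    skeleton = torch.nn.Linear(4, 4)  # parameters are created on the meta device

assert skeleton.weight.is_meta  # no real storage has been allocated yet

# Materialize each parameter on CPU, as the state_dict loop above does.
set_module_tensor_to_device(skeleton, "weight", "cpu", torch.zeros(4, 4))
set_module_tensor_to_device(skeleton, "bias", "cpu", torch.zeros(4))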


@@ -55,7 +55,7 @@ WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant):
def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant=None):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
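# With `variant` now defaulting to None, call sites may omit it, as in
# optimize.py above:
#     resolved_archive_file, is_sharded = extract_local_archive_file(model_path, subfolder="")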


@@ -0,0 +1,193 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is adapted from
# https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L516
#
# MIT License
#
# Copyright (c) 2023 Georgi Gerganov
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
from torch.serialization import StorageType
import pickle
import zipfile
import io
from typing import Dict, IO, Any, Callable
from dataclasses import dataclass
from .common import invalidInputError
item_size = {torch.bfloat16: 2,
torch.float16: 2,
torch.int: 4,
torch.float: 4,
torch.float32: 4,
torch.int8: 1}
@dataclass
class LazyStorage:
load: Callable[[int, int], torch.Tensor]
kind: StorageType
description: str
@dataclass
class LazyTensor:
_load: Callable[[], torch.Tensor]
shape: list[int]
data_type: torch.dtype
description: str
def load(self) -> torch.Tensor:
ret = self._load()
return ret
def to(self, data_type):
# self.validate_conversion_to(data_type)
def load() -> torch.Tensor:
return self.load().to(data_type)
return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')
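# Note: `to` reads no data; it returns a new LazyTensor whose loader fetches
# the bytes and casts only when `.load()` is eventually called, e.g.
#     lt.to(torch.float16).load()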
def _load(pickle_fp, map_location, picklemodule, pickle_file='data.pkl', zip_file=None):
load_module_mapping: Dict[str, str] = {
'torch.tensor': 'torch._tensor'
}
class LazyUnpickler(picklemodule.Unpickler):
def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile):
super().__init__(fp)
self.data_base_path = data_base_path
self.zip_file = zip_file
def persistent_load(self, pid):
data_type = pid[1].dtype
filename_stem = pid[2]
filename = f'{self.data_base_path}/{filename_stem}'
info = self.zip_file.getinfo(filename)
def load(offset: int, elm_count: int):
dtype = data_type
fp = self.zip_file.open(info)
fp.seek(offset * item_size[dtype])
size = elm_count * item_size[dtype]
data = fp.read(size)
return torch.frombuffer(bytearray(data), dtype=dtype)
description = f'storage data_type={data_type} ' \
f'path-in-zip={filename} path={self.zip_file.filename}'
return LazyStorage(load=load, kind=pid[1], description=description)
@staticmethod
def lazy_rebuild_tensor_v2(storage: Any,
storage_offset: Any,
size: Any,
stride: Any,
requires_grad: Any,
backward_hooks: Any,
metadata: Any = None) -> LazyTensor:
invalidInputError(isinstance(storage, LazyStorage),
"storage should be an instance of class `LazyStorage`, "
f"but got {type(storage)}.")
def load() -> torch.Tensor:
# For a contiguous row-major tensor, stride[0] * size[0] is the total element count.
elm_count = stride[0] * size[0]
return storage.load(storage_offset, elm_count).reshape(size)
description = f'pickled storage_offset={storage_offset} in {storage.description}'
return LazyTensor(load, list(size), storage.kind.dtype, description)
@staticmethod
def rebuild_from_type_v2(func, new_type, args, state):
return func(*args)
CLASSES: dict[tuple[str, str], Any] = {
('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
('torch', 'Tensor'): LazyTensor,
}
def find_class(self, mod_name, name):
if (mod_name, name) in self.CLASSES:
return self.CLASSES[(mod_name, name)]
if type(name) is str and 'Storage' in name:
try:
return StorageType(name)
except KeyError:
pass
mod_name = load_module_mapping.get(mod_name, mod_name)
return super().find_class(mod_name, name)
unpickler = LazyUnpickler(pickle_fp,
data_base_path=pickle_file,
zip_file=zip_file)
result = unpickler.load()
return result
# This can only be used on Hugging Face transformers checkpoints saved in torch's zip format.
def lazyload(
f,
*args,
**kwargs
):
if isinstance(f, io.BufferedIOBase):
fp = f
else:
fp = open(f, 'rb')
zf = zipfile.ZipFile(fp)
pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')]
invalidInputError(len(pickle_paths) == 1,
"Exactly one pickle file should be found in the checkpoint, "
f"but got {pickle_paths}.")
pickle_fp = zf.open(pickle_paths[0], 'r')
state_dict = _load(pickle_fp, None, pickle, pickle_file=pickle_paths[0][:-4], zip_file=zf)
return state_dict
class LazyLoadTensors:
def __init__(self):
self.torch_load = torch.load
def __enter__(self):
torch.load = lazyload
def __exit__(self, exc_type, exc_value, traceback):
torch.load = self.torch_load
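A minimal usage sketch for the lazy loader (the checkpoint path and key are illustrative; only checkpoints stored in torch's zip format are supported):

with LazyLoadTensors():
    state_dict = torch.load("pytorch_model.bin")  # values are LazyTensor objects; no data read yet
weight = state_dict["model.embed_tokens.weight"].load()  # bytes are read from the zip only here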


@@ -22,6 +22,7 @@ import shutil
from bigdl.llm import llm_convert
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.optimize import optimize_model, load_low_bit, low_memory_init
llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH')
@@ -87,5 +88,22 @@ class TestConvertModel(TestCase):
newModel = AutoModelForCausalLM.load_low_bit(tempdir)
assert newModel is not None
def test_optimize_transformers_llama(self):
from transformers import AutoModelForCausalLM as AutoCLM
with tempfile.TemporaryDirectory(dir=output_dir) as tempdir:
model = AutoCLM.from_pretrained(llama_model_path,
torch_dtype="auto",
low_cpu_mem_usage=True,
trust_remote_code=True)
model = optimize_model(model)
model.save_low_bit(tempdir)
with low_memory_init():
new_model = AutoCLM.from_pretrained(tempdir,
torch_dtype="auto",
trust_remote_code=True)
new_model = load_low_bit(new_model,
model_path=tempdir)
assert new_model is not None
if __name__ == '__main__':
pytest.main([__file__])