LLM: Adapt transformers models for optimize_model SL (#9022)
* LLM: Adapt transformers model for SL
This commit is contained in:
parent f64257a093
commit 548e4dd5fe
4 changed files with 286 additions and 20 deletions
@@ -24,6 +24,12 @@ from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.utils import extract_local_archive_file, get_local_shard_files
import transformers
from transformers import PreTrainedModel
from .utils.common import MuteHFLogger
from .utils.lazy_load_torch import LazyLoadTensors
from contextlib import ExitStack, contextmanager


# Simulate the Hugging Face format
@@ -37,7 +43,14 @@ def _save_low_bit(self, save_dir, *args, **kwargs):
                      f" load_in_4bit or load_in_low_bit parameter to load a 4-bit model first.")
    os.makedirs(save_dir, exist_ok=True)
    model_path = os.path.join(save_dir, PYTORCH_MODEL_NAME)
    torch.save(self.state_dict(), model_path, *args, **kwargs)
    if isinstance(self, PreTrainedModel):
        # We borrowed this method to adapt to Transformer model cases
        # as much as possible, and later we may merge these two situations
        self.save_pretrained(save_dir)
    else:
        # TODO: For the lowbit model still larger than 8GB,
        # save it into shards.
        torch.save(self.state_dict(), model_path, *args, **kwargs)
    with open(os.path.join(save_dir, CONFIG_NAME), "w") as json_file:
        json.dump(self._bigdl_config, json_file)
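For orientation, here is a minimal usage sketch of the save path this hunk extends. It assumes `save_low_bit` is bound to this `_save_low_bit` on models returned by `optimize_model` (as the unit test at the bottom of this diff does); the checkpoint and output paths are placeholders.

    from transformers import AutoModelForCausalLM
    from bigdl.llm.optimize import optimize_model

    # Placeholder path; any local Hugging Face checkpoint works the same way.
    model = AutoModelForCausalLM.from_pretrained("/path/to/llama", torch_dtype="auto")
    model = optimize_model(model)          # convert the weights to a low-bit representation
    model.save_low_bit("./llama-low-bit")  # PreTrainedModel branch: save_pretrained + bigdl config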
@@ -49,14 +62,44 @@ class DisableTorchAllocTensor():
    def __init__(self) -> None:
        self._old_torch_load_state_dict = Module.load_state_dict
        self._old_torch_to_device = Module.to
        self._old_torch_load_from_state_dict = Module._load_from_state_dict
        # Chatglm2 initializes weights manually,
        # and `skip_init` initializes on `cpu` by default
        self._old_skip_init = torch.nn.utils.skip_init

    def __enter__(self):
        Module.load_state_dict = lambda *args, **kwargs: _IncompatibleKeys([], [])
        Module._load_from_state_dict = lambda *args, **kwargs: None
        Module.to = lambda self, *args, **kwargs: self

        def skip_init_on_meta(module_cls, *args, **kwargs):
            kwargs['device'] = 'meta'
            return self._old_skip_init(module_cls, *args, **kwargs)
        torch.nn.utils.skip_init = skip_init_on_meta

    def __exit__(self, exc_type, exc_value, traceback):
        Module.load_state_dict = self._old_torch_load_state_dict
        Module._load_from_state_dict = self._old_torch_load_from_state_dict
        Module.to = self._old_torch_to_device
        torch.nn.utils.skip_init = self._old_skip_init


class ContextManagers:
    """
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers.
    Adaptation of `ContextManagers` in the `fastcore` library.
    """

    def __init__(self, context_managers):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        for context_manager in self.context_managers:
            self.stack.enter_context(context_manager)

    def __exit__(self, *args, **kwargs):
        self.stack.__exit__(*args, **kwargs)


def low_bit_sanity_check(model_path):
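`DisableTorchAllocTensor` works by temporarily swapping out `torch.nn.Module` methods and restoring them on exit. A self-contained sketch of that save/patch/restore pattern (`NoOpModuleTo` is an illustrative name, not part of this patch):

    import torch
    from torch.nn import Module

    class NoOpModuleTo:
        def __enter__(self):
            self._old_to = Module.to
            Module.to = lambda self, *args, **kwargs: self  # make .to() a no-op

        def __exit__(self, exc_type, exc_value, traceback):
            Module.to = self._old_to  # always restore the original method

    with NoOpModuleTo():
        layer = torch.nn.Linear(4, 4).to("meta")  # returns the module unchanged; nothing moves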
@@ -76,31 +119,43 @@ def low_bit_sanity_check(model_path):
    return low_bit


def load_low_bit(model_or_creator, model_path, **kwargs):
    is_creator = not isinstance(model_or_creator, torch.nn.Module) \
        and callable(model_or_creator)
    low_bit = low_bit_sanity_check(model_path)
@contextmanager
def low_memory_init():
    init_contexts = []
    init_contexts.extend([init_empty_weights(), DisableTorchAllocTensor()])
    # Load everything except Tensors' parameters
    init_contexts.append(LazyLoadTensors())
    # As we have muted `torch.load`, this will trigger a key-missing warning in hf,
    # but that does not matter since we will load again later.
    init_contexts.append(MuteHFLogger(logger=transformers.modeling_utils.logger))
    with ContextManagers(init_contexts):
        yield


def load_low_bit(model, model_path):
    low_bit = low_bit_sanity_check(model_path)
    invalidInputError(isinstance(model, torch.nn.Module),
                      "model should be an instance of "
                      f"`torch.nn.Module`, but got {type(model)} at last.")
    if low_bit:
        # a creator
        if is_creator:
            with init_empty_weights(), DisableTorchAllocTensor():
                model = model_or_creator(**kwargs)
        else:
            model = model_or_creator
        invalidInputError(isinstance(model, torch.nn.Module),
                          "model_or_creator should be an instance of "
                          "`torch.nn.Module` or a method that returns "
                          f"an instance of `torch.nn.Module`, but got {type(model)} at last.")
        qtype = ggml_tensor_qtype[low_bit]
        model = ggml_convert_low_bit(model, qtype=qtype, convert_shape_only=True)

    state_dict = torch.load(os.path.join(model_path, PYTORCH_MODEL_NAME))
    if is_creator:
    resolved_archive_file, is_sharded = extract_local_archive_file(model_path, subfolder="")
    if is_sharded:
        # For now only sharded transformers models
        # can run in this branch.
        resolved_archive_file, _ = \
            get_local_shard_files(model_path,
                                  resolved_archive_file,
                                  subfolder="")
    else:
        resolved_archive_file = [os.path.join(model_path, PYTORCH_MODEL_NAME)]

    for model_file in resolved_archive_file:
        state_dict = torch.load(model_file)
        for param_name, param in state_dict.items():
            set_module_tensor_to_device(model, param_name, "cpu", param)
    else:
        model.load_state_dict(state_dict=state_dict)
    return model
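Taken together, the intended flow (mirrored by the new unit test at the end of this diff) is: build the model skeleton under `low_memory_init()` so no real weights are allocated, then fill in the low-bit weights from a saved directory with `load_low_bit`. A minimal sketch; the saved directory is a placeholder written earlier by `save_low_bit`:

    from transformers import AutoModelForCausalLM
    from bigdl.llm.optimize import low_memory_init, load_low_bit

    saved_dir = "./llama-low-bit"  # placeholder
    with low_memory_init():
        # Architecture/config only; tensor data is not materialized here.
        model = AutoModelForCausalLM.from_pretrained(saved_dir,
                                                     torch_dtype="auto",
                                                     trust_remote_code=True)
    model = load_low_bit(model, model_path=saved_dir)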
@@ -55,7 +55,7 @@ WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"


def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant):
def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant=None):
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if os.path.isfile(
        os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
193 python/llm/src/bigdl/llm/utils/lazy_load_torch.py Normal file
@@ -0,0 +1,193 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is adapted from
# https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L516
#
# MIT License
#
# Copyright (c) 2023 Georgi Gerganov
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
from torch.serialization import StorageType
import pickle
import zipfile
import io
from typing import Dict, IO, Any, Callable
from dataclasses import dataclass
from .common import invalidInputError


item_size = {torch.bfloat16: 2,
             torch.float16: 2,
             torch.int: 4,
             torch.float: 4,
             torch.float32: 4,
             torch.int8: 1}

@dataclass
class LazyStorage:
    load: Callable[[int, int], torch.Tensor]
    kind: StorageType
    description: str


@dataclass
class LazyTensor:
    _load: Callable[[], torch.Tensor]
    shape: list[int]
    data_type: torch.dtype
    description: str

    def load(self) -> torch.Tensor:
        ret = self._load()
        return ret

    def to(self, data_type):
        # self.validate_conversion_to(data_type)

        def load() -> torch.Tensor:
            print(f"to {data_type}")
            return self.load().to(data_type)
        return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')

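`LazyStorage`/`LazyTensor` defer reading tensor bytes until `load()` is called; `to()` merely wraps another load closure. A toy illustration, assuming the module path from the file header above is importable (the zeros tensor stands in for a real checkpoint entry):

    import torch
    from bigdl.llm.utils.lazy_load_torch import LazyTensor

    lazy = LazyTensor(_load=lambda: torch.zeros(2, 3),
                      shape=[2, 3],
                      data_type=torch.float32,
                      description="toy zeros tensor")
    half = lazy.to(torch.float16)   # still lazy: nothing has been read or converted yet
    print(half.load().dtype)        # torch.float16, materialized only here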
def _load(pickle_fp, map_location, picklemoudle, pickle_file='data.pkl', zip_file=None):

    load_module_mapping: Dict[str, str] = {
        'torch.tensor': 'torch._tensor'
    }

    class LazyUnpickler(picklemoudle.Unpickler):
        def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile):
            super().__init__(fp)
            self.data_base_path = data_base_path
            self.zip_file = zip_file

        def persistent_load(self, pid):
            data_type = pid[1].dtype
            filename_stem = pid[2]
            filename = f'{self.data_base_path}/{filename_stem}'
            info = self.zip_file.getinfo(filename)

            def load(offset: int, elm_count: int):
                dtype = data_type
                fp = self.zip_file.open(info)
                fp.seek(offset * item_size[dtype])
                size = elm_count * item_size[dtype]
                data = fp.read(size)
                return torch.frombuffer(bytearray(data), dtype=dtype)
            description = f'storage data_type={data_type} ' \
                          f'path-in-zip={filename} path={self.zip_file.filename}'
            return LazyStorage(load=load, kind=pid[1], description=description)

        @staticmethod
        def lazy_rebuild_tensor_v2(storage: Any,
                                   storage_offset: Any,
                                   size: Any,
                                   stride: Any,
                                   requires_grad: Any,
                                   backward_hooks: Any,
                                   metadata: Any = None) -> LazyTensor:
            invalidInputError(isinstance(storage, LazyStorage),
                              "storage should be an instance of class `LazyStorage`, "
                              f"but got {type(storage)}.")

            def load() -> torch.Tensor:
                elm_count = stride[0] * size[0]
                return storage.load(storage_offset, elm_count).reshape(size)
            description = f'pickled storage_offset={storage_offset} in {storage.description}'
            return LazyTensor(load, list(size), storage.kind.dtype, description)

        @staticmethod
        def rebuild_from_type_v2(func, new_type, args, state):
            return func(*args)

        CLASSES: dict[tuple[str, str], Any] = {
            ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
            ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
            ('torch', 'Tensor'): LazyTensor,
        }

        def find_class(self, mod_name, name):
            if (mod_name, name) in self.CLASSES:
                return self.CLASSES[(mod_name, name)]
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            mod_name = load_module_mapping.get(mod_name, mod_name)
            return super().find_class(mod_name, name)

    unpickler = LazyUnpickler(pickle_fp,
                              data_base_path=pickle_file,
                              zip_file=zip_file)
    result = unpickler.load()

    return result

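For context on what `persistent_load` indexes into: a checkpoint written by `torch.save` in the default zip format contains a `data.pkl` describing the tensors plus one raw-bytes entry per storage, and the `load(offset, elm_count)` closure above reads slices of those entries. A quick inspection sketch (the checkpoint path is a placeholder):

    import zipfile

    with zipfile.ZipFile("pytorch_model.bin") as zf:  # placeholder path
        for name in zf.namelist():
            print(name)  # e.g. archive/data.pkl, archive/data/0, archive/data/1, ...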
# This can only be used on huggingface transformers loaded from a zip file.
def lazyload(
    f,
    *args,
    **kwargs
):
    if isinstance(f, io.BufferedIOBase):
        fp = f
    else:
        fp = open(f, 'rb')
    zf = zipfile.ZipFile(fp)
    pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')]
    invalidInputError(len(pickle_paths) == 1,
                      "There should be only one pickle path found, "
                      f"but got {pickle_paths}. ")
    pickle_fp = zf.open(pickle_paths[0], 'r')
    state_dict = _load(pickle_fp, None, pickle, pickle_file=pickle_paths[0][:-4], zip_file=zf)
    return state_dict


class LazyLoadTensors:
    def __init__(self):
        self.torch_load = torch.load

    def __enter__(self):
        torch.load = lazyload

    def __exit__(self, exc_type, exc_value, traceback):
        torch.load = self.torch_load

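As used by `low_memory_init` earlier in this diff, `LazyLoadTensors` simply swaps `torch.load` for `lazyload` for the duration of a `with` block, so callers receive `LazyTensor` stubs instead of materialized tensors. A small sketch (the checkpoint path is a placeholder):

    import torch
    from bigdl.llm.utils.lazy_load_torch import LazyLoadTensors

    with LazyLoadTensors():
        state_dict = torch.load("pytorch_model.bin")  # entries are LazyTensor stubs; data not read yet
    # Outside the block torch.load behaves normally again.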
@@ -22,6 +22,7 @@ import shutil

from bigdl.llm import llm_convert
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.optimize import optimize_model, load_low_bit, low_memory_init


llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH')
@@ -87,5 +88,22 @@ class TestConvertModel(TestCase):
            newModel = AutoModelForCausalLM.load_low_bit(tempdir)
            assert newModel is not None

    def test_optimize_transformers_llama(self):
        from transformers import AutoModelForCausalLM as AutoCLM
        with tempfile.TemporaryDirectory(dir=output_dir) as tempdir:
            model = AutoCLM.from_pretrained(llama_model_path,
                                            torch_dtype="auto",
                                            low_cpu_mem_usage=True,
                                            trust_remote_code=True)
            model = optimize_model(model)
            model.save_low_bit(tempdir)
            with low_memory_init():
                new_model = AutoCLM.from_pretrained(tempdir,
                                                    torch_dtype="auto",
                                                    trust_remote_code=True)
            new_model = load_low_bit(new_model,
                                     model_path=tempdir)
            assert new_model is not None


if __name__ == '__main__':
    pytest.main([__file__])