LLM: Adapt transformers models for optimize model SL (#9022)

* LLM: Adapt transformers model for SL

parent f64257a093, commit 548e4dd5fe

4 changed files with 286 additions and 20 deletions
@@ -24,6 +24,12 @@ from accelerate import init_empty_weights
 from accelerate.utils import set_module_tensor_to_device
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.utils import extract_local_archive_file, get_local_shard_files
+import transformers
+from transformers import PreTrainedModel
+from .utils.common import MuteHFLogger
+from .utils.lazy_load_torch import LazyLoadTensors
+from contextlib import ExitStack, contextmanager


 # Simulate the Hugging Face format
@@ -37,7 +43,14 @@ def _save_low_bit(self, save_dir, *args, **kwargs):
                       f" load_in_4bit or load_in_low_bit parameter to load a 4-bit model first.")
     os.makedirs(save_dir, exist_ok=True)
     model_path = os.path.join(save_dir, PYTORCH_MODEL_NAME)
-    torch.save(self.state_dict(), model_path, *args, **kwargs)
+    if isinstance(self, PreTrainedModel):
+        # We borrowed this method to adapt to Transformer model cases
+        # as much as possible, and later we may merge these two situations
+        self.save_pretrained(save_dir)
+    else:
+        # TODO: For the lowbit model still larger than 8GB,
+        # save it into shards.
+        torch.save(self.state_dict(), model_path, *args, **kwargs)
     with open(os.path.join(save_dir, CONFIG_NAME), "w") as json_file:
         json.dump(self._bigdl_config, json_file)
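For orientation, a minimal sketch of how this save path is exercised, mirroring the unit test added at the end of this commit; the checkpoint paths are placeholders and this snippet is not part of the diff itself.

from transformers import AutoModelForCausalLM
from bigdl.llm.optimize import optimize_model

# Placeholder path; any local HF checkpoint directory works the same way.
model = AutoModelForCausalLM.from_pretrained("/path/to/llama",
                                             torch_dtype="auto",
                                             low_cpu_mem_usage=True)
model = optimize_model(model)         # produce the low-bit model
model.save_low_bit("/path/to/saved")  # PreTrainedModel branch above: saved via save_pretrained()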
@@ -49,14 +62,44 @@ class DisableTorchAllocTensor():
     def __init__(self) -> None:
         self._old_torch_load_state_dict = Module.load_state_dict
         self._old_torch_to_device = Module.to
+        self._old_torch_load_from_state_dict = Module._load_from_state_dict
+        # Chatglm2 init weights manually,
+        # and `skip_init` init on `cpu` by default
+        self._old_skip_init = torch.nn.utils.skip_init

     def __enter__(self):
         Module.load_state_dict = lambda *args, **kwargs: _IncompatibleKeys([], [])
+        Module._load_from_state_dict = lambda *args, **kwargs: None
         Module.to = lambda self, *args, **kwargs: self
+
+        def skip_init_on_meta(module_cls, *args, **kwargs):
+            kwargs['device'] = 'meta'
+            return self._old_skip_init(module_cls, *args, **kwargs)
+        torch.nn.utils.skip_init = skip_init_on_meta

     def __exit__(self, exc_type, exc_value, traceback):
         Module.load_state_dict = self._old_torch_load_state_dict
+        Module._load_from_state_dict = self._old_torch_load_from_state_dict
         Module.to = self._old_torch_to_device
+        torch.nn.utils.skip_init = self._old_skip_init
+
+
+class ContextManagers:
+    """
+    Wrapper for `contextlib.ExitStack` which enters a collection of context managers.
+    Adaptation of `ContextManagers` in the `fastcore` library.
+    """
+
+    def __init__(self, context_managers):
+        self.context_managers = context_managers
+        self.stack = ExitStack()
+
+    def __enter__(self):
+        for context_manager in self.context_managers:
+            self.stack.enter_context(context_manager)
+
+    def __exit__(self, *args, **kwargs):
+        self.stack.__exit__(*args, **kwargs)


 def low_bit_sanity_check(model_path):
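A side note on the `skip_init` patch above: forcing `device='meta'` means modules that initialize their own weights (such as ChatGLM2) still end up with shape-only parameters and no real allocation. A small standalone illustration, assuming PyTorch >= 1.10; it is not part of the diff.

import torch

def skip_init_on_meta(module_cls, *args, **kwargs):
    # Same trick as DisableTorchAllocTensor.__enter__ above.
    kwargs['device'] = 'meta'
    return torch.nn.utils.skip_init(module_cls, *args, **kwargs)

layer = skip_init_on_meta(torch.nn.Linear, 4096, 4096)
print(layer.weight.is_meta)  # True: no storage was allocated for the weights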
@@ -76,31 +119,43 @@ def low_bit_sanity_check(model_path):
     return low_bit


-def load_low_bit(model_or_creator, model_path, **kwargs):
-    is_creator = not isinstance(model_or_creator, torch.nn.Module) \
-        and callable(model_or_creator)
-    low_bit = low_bit_sanity_check(model_path)
+@contextmanager
+def low_memory_init():
+    init_contexts = []
+    init_contexts.extend([init_empty_weights(), DisableTorchAllocTensor()])
+    # Load everything except Tensors' parameters
+    init_contexts.append(LazyLoadTensors())
+    # As we have muted the `torch.load`, this will trigger a key missing warning in hf
+    # but this matters not for we will load again later.
+    init_contexts.append(MuteHFLogger(logger=transformers.modeling_utils.logger))
+    with ContextManagers(init_contexts):
+        yield
+
+
+def load_low_bit(model, model_path):
+    low_bit = low_bit_sanity_check(model_path)
+    invalidInputError(isinstance(model, torch.nn.Module),
+                      "model should be a instance of "
+                      f"`torch.nn.Module`, but got {type(model)} at last.")
     if low_bit:
-        # a creator
-        if is_creator:
-            with init_empty_weights(), DisableTorchAllocTensor():
-                model = model_or_creator(**kwargs)
-        else:
-            model = model_or_creator
-        invalidInputError(isinstance(model, torch.nn.Module),
-                          "model_or_creator should be a instance of "
-                          "`torch.nn.Module`or a method that returns "
-                          f"an instance of `torch.nn.Module`, but got {type(model)} at last.")
         qtype = ggml_tensor_qtype[low_bit]
         model = ggml_convert_low_bit(model, qtype=qtype, convert_shape_only=True)

-        state_dict = torch.load(os.path.join(model_path, PYTORCH_MODEL_NAME))
-        if is_creator:
+        resolved_archive_file, is_sharded = extract_local_archive_file(model_path, subfolder="")
+        if is_sharded:
+            # For now only shards transformers models
+            # can run in this branch.
+            resolved_archive_file, _ = \
+                get_local_shard_files(model_path,
+                                      resolved_archive_file,
+                                      subfolder="")
+        else:
+            resolved_archive_file = [os.path.join(model_path, PYTORCH_MODEL_NAME)]
+
+        for model_file in resolved_archive_file:
+            state_dict = torch.load(model_file)
             for param_name, param in state_dict.items():
                 set_module_tensor_to_device(model, param_name, "cpu", param)
-        else:
-            model.load_state_dict(state_dict=state_dict)
     return model
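Taken together, the intended load path looks roughly like the sketch below; it mirrors the new test_optimize_transformers_llama test later in this commit, the directory is a placeholder, and the snippet is not part of the diff.

from transformers import AutoModelForCausalLM
from bigdl.llm.optimize import low_memory_init, load_low_bit

with low_memory_init():
    # Empty weights + lazy torch.load + muted HF logger: the skeleton is built
    # without materializing the original fp16/fp32 checkpoint.
    model = AutoModelForCausalLM.from_pretrained("/path/to/saved",
                                                 torch_dtype="auto",
                                                 trust_remote_code=True)

# Convert the skeleton to the saved low-bit layout, then fill in the real
# quantized weights from pytorch_model.bin (sharded or not).
model = load_low_bit(model, model_path="/path/to/saved")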
@@ -55,7 +55,7 @@ WEIGHTS_NAME = "pytorch_model.bin"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"


-def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant):
+def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant=None):
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
     if os.path.isfile(
         os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
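The only change in this file is the `variant=None` default, so the new `load_low_bit` can call `extract_local_archive_file(model_path, subfolder="")` without passing a variant. Below is a simplified stand-in (for illustration only, not the transformers helper itself) showing what the resolved file name becomes.

def add_variant_demo(weights_name: str, variant=None) -> str:
    # Simplified stand-in for transformers' _add_variant behaviour.
    if variant is None:
        return weights_name
    name, ext = weights_name.rsplit(".", 1)
    return f"{name}.{variant}.{ext}"

print(add_variant_demo("pytorch_model.bin"))          # pytorch_model.bin
print(add_variant_demo("pytorch_model.bin", "fp16"))  # pytorch_model.fp16.bin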
python/llm/src/bigdl/llm/utils/lazy_load_torch.py (new file, 193 additions)
@@ -0,0 +1,193 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===========================================================================
+#
+# This file is adapted from
+# https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L516
+#
+# MIT License
+#
+# Copyright (c) 2023 Georgi Gerganov
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import torch
+from torch.serialization import StorageType
+import pickle
+import zipfile
+import io
+from typing import Dict, IO, Any, Callable
+from dataclasses import dataclass
+from .common import invalidInputError
+
+
+item_size = {torch.bfloat16: 2,
+             torch.float16: 2,
+             torch.int: 4,
+             torch.float: 4,
+             torch.float32: 4,
+             torch.int8: 1}
+
+
+@dataclass
+class LazyStorage:
+    load: Callable[[int, int], torch.Tensor]
+    kind: StorageType
+    description: str
+
+
+@dataclass
+class LazyTensor:
+    _load: Callable[[], torch.Tensor]
+    shape: list[int]
+    data_type: torch.dtype
+    description: str
+
+    def load(self) -> torch.Tensor:
+        ret = self._load()
+        return ret
+
+    def to(self, data_type):
+        # self.validate_conversion_to(data_type)
+
+        def load() -> torch.Tensor:
+            print(f"to {data_type}")
+            return self.load().to(data_type)
+        return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')
+
+
+def _load(pickle_fp, map_location, picklemoudle, pickle_file='data.pkl', zip_file=None):
+
+    load_module_mapping: Dict[str, str] = {
+        'torch.tensor': 'torch._tensor'
+    }
+
+    class LazyUnpickler(picklemoudle.Unpickler):
+        def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile):
+            super().__init__(fp)
+            self.data_base_path = data_base_path
+            self.zip_file = zip_file
+
+        def persistent_load(self, pid):
+            data_type = pid[1].dtype
+            filename_stem = pid[2]
+            filename = f'{self.data_base_path}/{filename_stem}'
+            info = self.zip_file.getinfo(filename)
+
+            def load(offset: int, elm_count: int):
+                dtype = data_type
+                fp = self.zip_file.open(info)
+                fp.seek(offset * item_size[dtype])
+                size = elm_count * item_size[dtype]
+                data = fp.read(size)
+                return torch.frombuffer(bytearray(data), dtype=dtype)
+            description = f'storage data_type={data_type} ' \
+                          'path-in-zip={filename} path={self.zip_file.filename}'
+            return LazyStorage(load=load, kind=pid[1], description=description)
+
+        @staticmethod
+        def lazy_rebuild_tensor_v2(storage: Any,
+                                   storage_offset: Any,
+                                   size: Any,
+                                   stride: Any,
+                                   requires_grad: Any,
+                                   backward_hooks: Any,
+                                   metadata: Any = None) -> LazyTensor:
+            invalidInputError(isinstance(storage, LazyStorage),
+                              "storage should be an instance of class `LazyStorage`, "
+                              f"but get {type(storage)}.")
+
+            def load() -> torch.Tensor:
+                elm_count = stride[0] * size[0]
+                return storage.load(storage_offset, elm_count).reshape(size)
+            description = f'pickled storage_offset={storage_offset} in {storage.description}'
+            return LazyTensor(load, list(size), storage.kind.dtype, description)
+
+        @staticmethod
+        def rebuild_from_type_v2(func, new_type, args, state):
+            return func(*args)
+
+        CLASSES: dict[tuple[str, str], Any] = {
+            ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
+            ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
+            ('torch', 'Tensor'): LazyTensor,
+        }
+
+        def find_class(self, mod_name, name):
+            if (mod_name, name) in self.CLASSES:
+                return self.CLASSES[(mod_name, name)]
+            if type(name) is str and 'Storage' in name:
+                try:
+                    return StorageType(name)
+                except KeyError:
+                    pass
+            mod_name = load_module_mapping.get(mod_name, mod_name)
+            return super().find_class(mod_name, name)
+
+    unpickler = LazyUnpickler(pickle_fp,
+                              data_base_path=pickle_file,
+                              zip_file=zip_file)
+    result = unpickler.load()
+
+    return result
+
+
+# This can only be used on huggingface transformers loaded from a zip file.
+def lazyload(
+    f,
+    *args,
+    **kwargs
+):
+    if isinstance(f, io.BufferedIOBase):
+        fp = f
+    else:
+        fp = open(f, 'rb')
+    zf = zipfile.ZipFile(fp)
+    pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')]
+    invalidInputError(len(pickle_paths) == 1,
+                      "There should be only one pickle_paths found, "
+                      f"but get {pickle_paths}. ")
+    pickle_fp = zf.open(pickle_paths[0], 'r')
+    state_dict = _load(pickle_fp, None, pickle, pickle_file=pickle_paths[0][:-4], zip_file=zf)
+    return state_dict


+class LazyLoadTensors:
+    def __init__(self):
+        self.torch_load = torch.load
+
+    def __enter__(self):
+        torch.load = lazyload
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        torch.load = self.torch_load
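For context, a short sketch (not part of the diff) of how LazyLoadTensors changes torch.load behaviour: values in the returned dict are LazyTensor handles whose bytes are only read when .load() is called. It applies only to zip-format checkpoints, as the comment above lazyload notes; the path and key name below are hypothetical.

import torch
from bigdl.llm.utils.lazy_load_torch import LazyLoadTensors

with LazyLoadTensors():
    # torch.load is temporarily replaced by lazyload; no tensor data is read yet.
    state_dict = torch.load("/path/to/saved/pytorch_model.bin")  # placeholder path

lazy_w = state_dict["model.layers.0.self_attn.q_proj.weight"]  # hypothetical key
print(lazy_w.shape, lazy_w.data_type)  # metadata only, no file I/O so far
real_w = lazy_w.load()                 # bytes are read from the zip only now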
@@ -22,6 +22,7 @@ import shutil

 from bigdl.llm import llm_convert
 from bigdl.llm.transformers import AutoModelForCausalLM
+from bigdl.llm.optimize import optimize_model, load_low_bit, low_memory_init


 llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH')
@@ -87,5 +88,22 @@ class TestConvertModel(TestCase):
         newModel = AutoModelForCausalLM.load_low_bit(tempdir)
         assert newModel is not None

+    def test_optimize_transformers_llama(self):
+        from transformers import AutoModelForCausalLM as AutoCLM
+        with tempfile.TemporaryDirectory(dir=output_dir) as tempdir:
+            model = AutoCLM.from_pretrained(llama_model_path,
+                                            torch_dtype="auto",
+                                            low_cpu_mem_usage=True,
+                                            trust_remote_code=True)
+            model = optimize_model(model)
+            model.save_low_bit(tempdir)
+            with low_memory_init():
+                new_model = AutoCLM.from_pretrained(tempdir,
+                                                    torch_dtype="auto",
+                                                    trust_remote_code=True)
+            new_model = load_low_bit(new_model,
+                                     model_path=tempdir)
+            assert new_model is not None
+

 if __name__ == '__main__':
     pytest.main([__file__])