LLM: transformer int4 save and load (#8462)

* LLM: transformer int4 save and load

parent 04f2f04410
commit 81d655cda9

5 changed files with 164 additions and 14 deletions
@@ -41,7 +41,8 @@ from bigdl.llm.transformers.linear_int4 import LinearInt4, ParamsInt4
 import warnings
 
 
-def _replace_with_int4_linear(model, modules_to_not_convert=None, current_key_name=None):
+def _replace_with_int4_linear(model, modules_to_not_convert=None,
+                              current_key_name=None, convert_shape_only=False):
     has_been_replaced = False
     for name, module in model.named_children():
         if current_key_name is None:
@@ -59,10 +60,12 @@ def _replace_with_int4_linear(model, modules_to_not_convert=None, current_key_name=None):
                     )
 
                     # Copy the weights
-                    new_linear._parameters['weight'] = ParamsInt4(data=module.weight.data,
-                                                                  requires_grad=False,
-                                                                  quantized=False,
-                                                                  _shape=None).to("cpu")
+                    paramsint4 = ParamsInt4(data=module.weight.data,
+                                            requires_grad=False,
+                                            quantized=False,
+                                            convert_shape_only=convert_shape_only,
+                                            _shape=None).to("cpu")
+                    new_linear._parameters['weight'] = paramsint4
                     if module.bias is not None:
                         new_linear._parameters['bias'] = nn.Parameter(module.bias.data).to("cpu")
@@ -83,10 +86,10 @@ def _replace_with_int4_linear(model, modules_to_not_convert=None, current_key_name=None):
     return model, has_been_replaced
 
 
-def ggml_convert_int4(model):
+def ggml_convert_int4(model, convert_shape_only=False):
     modules_to_not_convert = []  # ["lm_head"]
     model, has_been_replaced = _replace_with_int4_linear(
-        model, modules_to_not_convert, None
+        model, modules_to_not_convert, None, convert_shape_only=convert_shape_only
     )
     if not has_been_replaced:
         warnings.warn(
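Taken together, these hunks give `ggml_convert_int4` two modes. A hedged sketch of how they are meant to be called (the checkpoint path is a placeholder; the actual q4_0 quantization still happens inside `ParamsInt4`):

from transformers import AutoModelForCausalLM
from bigdl.llm.transformers.convert import ggml_convert_int4

# Placeholder path; any float32 HF checkpoint works.
model = AutoModelForCausalLM.from_pretrained("/path/to/hf-model")

# Mode 1 -- full conversion: swap every eligible nn.Linear for LinearInt4
# and wrap its weight in ParamsInt4 for q4_0 quantization.
model = ggml_convert_int4(model)

# Mode 2 -- shape-only conversion (used when restoring a saved int4
# checkpoint): the same module surgery runs, but ggml_quantize_q4_0 is
# skipped, so the weights merely take on the packed int4 shape and can be
# filled in later from a saved state dict.
# model = ggml_convert_int4(model, convert_shape_only=True)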
@@ -60,7 +60,7 @@ scale_size_in_bytes = 4
 block_size_in_bytes = QK // 2 + scale_size_in_bytes
 
 
-def ggml_convert_int4(tensor: torch.Tensor):
+def ggml_convert_int4(tensor: torch.Tensor, convert_shape_only=False):
 
     invalidInputError(tensor.dtype == torch.float,
                       "Input tensor must be float32")
@@ -79,12 +79,14 @@ def ggml_convert_int4(tensor: torch.Tensor):
 
     hist = (ctypes.c_int64 * 16)()
 
-    ggml.ggml_quantize_q4_0(src, dst, n, k, hist)
+    if not convert_shape_only:
+        ggml.ggml_quantize_q4_0(src, dst, n, k, hist)
     return dst_tensor
 
 
 class ParamsInt4(torch.nn.Parameter):
-    def __new__(cls, data=None, requires_grad=True, old_data=None, quantized=False, _shape=None):
+    def __new__(cls, data=None, requires_grad=True, old_data=None,
+                quantized=False, _shape=None, convert_shape_only=False):
         if data is None:
             data = torch.empty(0)
@@ -92,13 +94,14 @@ class ParamsInt4(torch.nn.Parameter):
         self.data = data
         self.quantized = quantized
         self._shape = _shape
+        self.convert_shape_only = convert_shape_only
         return self
 
     def quantize(self, device):
         if not self.quantized:
             w = self.data.contiguous().float()
             # self.old_data = self.data
-            w_4bit = ggml_convert_int4(w)
+            w_4bit = ggml_convert_int4(w, convert_shape_only=self.convert_shape_only)
             self.data = w_4bit
             self.quantized = True
             self._shape = w.shape
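For scale: q4_0 packs each block of QK = 32 float32 values into QK // 2 = 16 nibble bytes plus one 4-byte scale, so `block_size_in_bytes` is 20 and a quantized weight is 20/128 of its fp32 size. A minimal sketch of that arithmetic (the helper name `q4_0_buffer_size` is hypothetical, not from the patch):

QK = 32                  # elements per q4_0 quantization block
scale_size_in_bytes = 4  # one float32 scale per block
block_size_in_bytes = QK // 2 + scale_size_in_bytes  # 16 + 4 = 20

def q4_0_buffer_size(n_elements: int) -> int:
    # Assumes n_elements is a multiple of QK, as q4_0 requires.
    return n_elements // QK * block_size_in_bytes

# A 4096 x 4096 linear weight: 64 MiB as fp32, 10 MiB as q4_0.
print(q4_0_buffer_size(4096 * 4096))  # 10485760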
@@ -15,7 +15,8 @@
 #
 
 import transformers
 import torch
+from transformers.configuration_utils import PretrainedConfig
+from .utils import extract_local_archive_file, load_state_dict, load
 
 
 class _BaseAutoModelClass:
@@ -29,12 +30,51 @@ class _BaseAutoModelClass:
         load_in_4bit = kwargs.pop("load_in_4bit", False)
         if load_in_4bit:
             kwargs["low_cpu_mem_usage"] = True
-        model = cls.HF_Model.from_pretrained(*args, **kwargs)
 
-        if load_in_4bit:
+        subfolder = kwargs.get("subfolder", "")
+        variant = kwargs.get("variant", None)
+        pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
+            if len(args) == 0 else args[0]
+
+        # For huggingface transformers, cls.HF_Model.from_pretrained can only restore the
+        # model in its original (unquantized) format;
+        # we convert the model to the quantized format afterwards.
+        model = None
+
+        # Read bigdl_transformers_int4 from config.json
+        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
+
+        bigdl_transformers_int4 = config_dict.pop("bigdl_transformers_int4", False)
+        if bigdl_transformers_int4:
+            # Avoid KeyError
+            kwargs["ignore_mismatched_sizes"] = True
+
+        model = cls.HF_Model.from_pretrained(*args, **kwargs)
+        print("Note: If there are warnings about mismatched sizes during the loading process, "
+              "please ignore them as they are part of the normal flow. "
+              "The model will be reconverted to the BigDL format after loading.")
+
+        # Note that the ggml_matmul_src1_x_src0_t operation cannot currently
+        # be recorded in AutoConfig,
+        # and this operation is not included in the core Hugging Face infrastructure.
+        if bigdl_transformers_int4:
+            from .convert import ggml_convert_int4
+            # We forcefully modify the model's definition
+            # and the tensor shape of the int4 weights without quantization.
+            model = ggml_convert_int4(model, convert_shape_only=True)
+            # Load the quantized model at last.
+            archive_file = extract_local_archive_file(pretrained_model_name_or_path,
+                                                      subfolder,
+                                                      variant)
+            state_dict = load_state_dict(archive_file)
+            load(model, state_dict)
+            del state_dict
+        elif load_in_4bit:
+            from .convert import ggml_convert_int4
             model = model.to("cpu")
             model = ggml_convert_int4(model)
+            model.config.update({"bigdl_transformers_int4": True})
 
         return model
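The save/load round trip hinges on the single flag persisted to config.json: a quantizing load writes `bigdl_transformers_int4` via `model.config.update`, and a later `from_pretrained` on the saved directory sees the flag, rebuilds the int4 module shapes without quantizing, and fills them from the checkpoint. A sketch of the intended usage (paths are placeholders):

from bigdl.llm.transformers import AutoModelForCausalLM

# First load: quantize the original HF checkpoint to int4.
model = AutoModelForCausalLM.from_pretrained("/path/to/llama", load_in_4bit=True)
model.save_pretrained("/path/to/llama-int4")  # config.json now carries the flag

# Later load: no load_in_4bit needed; the config flag triggers the
# shape-only convert followed by a state-dict load of the int4 weights.
model = AutoModelForCausalLM.from_pretrained("/path/to/llama-int4")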
python/llm/src/bigdl/llm/transformers/utils.py (new file, 94 lines)

@@ -0,0 +1,94 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Some parts of this file are adapted from
+# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py
+# which is licensed under the MIT license:
+#
+# MIT License
+#
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import os
+from transformers.modeling_utils import _add_variant
+from ..utils.common import invalidInputError
+from typing import Union
+import torch
+from torch import nn
+
+
+WEIGHTS_NAME = "pytorch_model.bin"
+
+def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant):
+    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+    print(os.path.join(pretrained_model_name_or_path,
+                       subfolder,
+                       _add_variant(WEIGHTS_NAME, variant)))
+    if os.path.isfile(
+        os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
+    ):
+        # Load from a PyTorch checkpoint
+        archive_file = os.path.join(
+            pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
+        )
+        return archive_file
+    else:
+        invalidInputError(False,
+                          f"No file named {_add_variant(WEIGHTS_NAME, variant)}"
+                          " found in directory"
+                          f" {pretrained_model_name_or_path}.")
+
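`_add_variant` here is the transformers helper that splices a variant tag in front of the weight file's extension, so the paths probed above look like this (a quick illustration, not part of the patch):

from transformers.modeling_utils import _add_variant

print(_add_variant("pytorch_model.bin", None))    # pytorch_model.bin
print(_add_variant("pytorch_model.bin", "fp16"))  # pytorch_model.fp16.bin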
+def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+    try:
+        return torch.load(checkpoint_file, map_location="cpu")
+    except Exception as e:
+        invalidInputError(False,
+                          f"Unable to load weights from pytorch checkpoint file"
+                          f" at '{checkpoint_file}'.")
+
+# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
+# so we need to apply the function recursively.
+def load(module: nn.Module, state_dict, prefix=""):
+    args = (state_dict, prefix, {}, True, [], [], [])
+    # Parameters of module and children will start with prefix.
+    # We can exit early if there are none in this state_dict.
+    if len([key for key in state_dict if key.startswith(prefix)]) > 0:
+        module._load_from_state_dict(*args)
+
+    for name, child in module._modules.items():
+        if child is not None:
+            load(child, state_dict, prefix + name + ".")
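A toy illustration of why `load` recurses (the two-layer module is hypothetical, and `load` as defined above is assumed to be in scope): `_load_from_state_dict` only fills a module's own parameters, so each child must be visited with an extended prefix to reach keys like `0.weight`.

import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
zeros = {k: torch.zeros_like(v) for k, v in net.state_dict().items()}

# Keys are "0.weight", "0.bias", "1.weight", "1.bias"; load() calls
# _load_from_state_dict on the Sequential itself (which owns no direct
# parameters), then on child "0" with prefix "0." and child "1" with "1.".
load(net, zeros)
assert float(net[0].weight.abs().sum()) == 0.0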
@@ -17,9 +17,11 @@
 #
 
 import pytest
 import os
+import tempfile
 from unittest import TestCase
 
 from bigdl.llm import llm_convert
+from bigdl.llm.transformers import AutoModelForCausalLM
 
 
 llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH')

@@ -62,6 +64,14 @@ class TestConvertModel(TestCase):
                                   outtype='int4')
         assert os.path.isfile(converted_model_path)
 
+    def test_transformer_convert_llama(self):
+        model = AutoModelForCausalLM.from_pretrained(llama_model_path,
+                                                     load_in_4bit=True)
+        tempdir = tempfile.mkdtemp(dir=output_dir)
+        model.save_pretrained(tempdir)
+        model = AutoModelForCausalLM.from_pretrained(tempdir)
+        assert model is not None
+
 
 if __name__ == '__main__':
     pytest.main([__file__])