LLM: transformer int4 save and load (#8462)

* LLM: transformer int4 save and load

parent 04f2f04410
commit 81d655cda9

5 changed files with 164 additions and 14 deletions
@@ -41,7 +41,8 @@ from bigdl.llm.transformers.linear_int4 import LinearInt4, ParamsInt4
 import warnings


-def _replace_with_int4_linear(model, modules_to_not_convert=None, current_key_name=None):
+def _replace_with_int4_linear(model, modules_to_not_convert=None,
+                              current_key_name=None, convert_shape_only=False):
     has_been_replaced = False
     for name, module in model.named_children():
         if current_key_name is None:
@@ -59,10 +60,12 @@ def _replace_with_int4_linear(model, modules_to_not_convert=None, current_key_na
                     )

                     # Copy the weights
-                    new_linear._parameters['weight'] = ParamsInt4(data=module.weight.data,
-                                                                  requires_grad=False,
-                                                                  quantized=False,
-                                                                  _shape=None).to("cpu")
+                    paramsint4 = ParamsInt4(data=module.weight.data,
+                                            requires_grad=False,
+                                            quantized=False,
+                                            convert_shape_only=convert_shape_only,
+                                            _shape=None).to("cpu")
+                    new_linear._parameters['weight'] = paramsint4
                     if module.bias is not None:
                         new_linear._parameters['bias'] = nn.Parameter(module.bias.data).to("cpu")

@@ -83,10 +86,10 @@ def _replace_with_int4_linear(model, modules_to_not_convert=None, current_key_na
     return model, has_been_replaced


-def ggml_convert_int4(model):
+def ggml_convert_int4(model, convert_shape_only=False):
     modules_to_not_convert = []  # ["lm_head"]
     model, has_been_replaced = _replace_with_int4_linear(
-        model, modules_to_not_convert, None
+        model, modules_to_not_convert, None, convert_shape_only=convert_shape_only
     )
     if not has_been_replaced:
         warnings.warn(
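The new convert_shape_only flag exists because this conversion now has two callers in the same commit (see the _BaseAutoModelClass.from_pretrained hunk further down). A condensed sketch of those two calls, with the surrounding from_pretrained logic elided and `model` standing for an already-loaded Hugging Face model:

from .convert import ggml_convert_int4

# Fresh quantization when the user passes load_in_4bit=True: Linear layers are swapped
# to LinearInt4 and their weights wrapped in ParamsInt4, which later quantizes them
# into q4_0 blocks.
model = ggml_convert_int4(model)

# Shape-only conversion when config.json already carries "bigdl_transformers_int4": True.
# The module structure and int4 tensor shapes are rebuilt, but ggml_quantize_q4_0 is
# skipped so the previously saved int4 state_dict can simply be copied in afterwards.
model = ggml_convert_int4(model, convert_shape_only=True)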
@@ -60,7 +60,7 @@ scale_size_in_bytes = 4
 block_size_in_bytes = QK // 2 + scale_size_in_bytes


-def ggml_convert_int4(tensor: torch.Tensor):
+def ggml_convert_int4(tensor: torch.Tensor, convert_shape_only=False):

     invalidInputError(tensor.dtype == torch.float,
                       "Input tensor must be float32")
@@ -79,12 +79,14 @@ def ggml_convert_int4(tensor: torch.Tensor):

     hist = (ctypes.c_int64 * 16)()

-    ggml.ggml_quantize_q4_0(src, dst, n, k, hist)
+    if not convert_shape_only:
+        ggml.ggml_quantize_q4_0(src, dst, n, k, hist)
     return dst_tensor


 class ParamsInt4(torch.nn.Parameter):
-    def __new__(cls, data=None, requires_grad=True, old_data=None, quantized=False, _shape=None):
+    def __new__(cls, data=None, requires_grad=True, old_data=None,
+                quantized=False, _shape=None, convert_shape_only=False):
         if data is None:
             data = torch.empty(0)

@@ -92,13 +94,14 @@ class ParamsInt4(torch.nn.Parameter):
         self.data = data
         self.quantized = quantized
         self._shape = _shape
+        self.convert_shape_only = convert_shape_only
         return self

     def quantize(self, device):
         if not self.quantized:
             w = self.data.contiguous().float()
             # self.old_data = self.data
-            w_4bit = ggml_convert_int4(w)
+            w_4bit = ggml_convert_int4(w, convert_shape_only=self.convert_shape_only)
             self.data = w_4bit
             self.quantized = True
             self._shape = w.shape
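For context on why the saved int4 tensors no longer match the original float shapes (and hence why both ignore_mismatched_sizes and convert_shape_only are needed below): ggml's q4_0 format packs each block of QK weights into QK // 2 nibble bytes plus one scale. A back-of-the-envelope sketch, assuming QK is 32 (the standard ggml q4_0 block size; only scale_size_in_bytes = 4 and the block_size_in_bytes formula are shown in the hunks above):

QK = 32                                                # assumed q4_0 block size
scale_size_in_bytes = 4
block_size_in_bytes = QK // 2 + scale_size_in_bytes    # 16 packed bytes + one fp32 scale = 20

def q4_0_buffer_size(n_float_elements: int) -> int:
    # every QK float32 weights collapse into one q4_0 block
    return (n_float_elements // QK) * block_size_in_bytes

# Example: a 4096 x 4096 Linear weight is 16,777,216 floats (64 MiB in fp32) but only
# 16,777,216 // 32 * 20 = 10,485,760 bytes (~10 MiB) once quantized, so a checkpoint
# written in this format cannot be loaded back into the unmodified model definition.
print(q4_0_buffer_size(4096 * 4096))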
@@ -15,7 +15,8 @@
 #

 import transformers
-import torch
+from transformers.configuration_utils import PretrainedConfig
+from .utils import extract_local_archive_file, load_state_dict, load


 class _BaseAutoModelClass:
@@ -29,12 +30,51 @@ class _BaseAutoModelClass:
         load_in_4bit = kwargs.pop("load_in_4bit", False)
         if load_in_4bit:
             kwargs["low_cpu_mem_usage"] = True
-        model = cls.HF_Model.from_pretrained(*args, **kwargs)

-        if load_in_4bit:
+        subfolder = kwargs.get("subfolder", "")
+        variant = kwargs.get("variant", None)
+        pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
+            if len(args) == 0 else args[0]
+
+        # For huggingface transformers cls.HF_Model.from_pretrained could only restore the model
+        # in the original format, which is not quantized,
+        # we can convert the model to quantized later.
+        model = None
+
+        # Read bigdl_transformers_int4 from config.json
+        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
+
+        bigdl_transformers_int4 = config_dict.pop("bigdl_transformers_int4", False)
+        if bigdl_transformers_int4:
+            # Avoid KeyError
+            kwargs["ignore_mismatched_sizes"] = True
+
+        model = cls.HF_Model.from_pretrained(*args, **kwargs)
+        print("Note: If there are warnings about mismatched during the loading process, "
+              "please ignore them as it is part of the normal flow. "
+              "The model will be reconverted to the format of BigDL after loading.")
+
+        # Note that the ggml_matmul_src1_x_src0_t operation cannot currently
+        # be recorded in AutoConfig,
+        # and this operation is not included in the core Hugging Face infrastructure.
+        if bigdl_transformers_int4:
+            from .convert import ggml_convert_int4
+            # We forcefully modify the model's definition
+            # and the tensor shape of int4 weights without quantization.
+            model = ggml_convert_int4(model, convert_shape_only=True)
+            # Load the quantized model at last.
+            archive_file = extract_local_archive_file(pretrained_model_name_or_path,
+                                                      subfolder,
+                                                      variant)
+            state_dict = load_state_dict(archive_file)
+            load(model, state_dict)
+            del state_dict
+        elif load_in_4bit:
             from .convert import ggml_convert_int4
             model = model.to("cpu")
             model = ggml_convert_int4(model)
+            model.config.update({"bigdl_transformers_int4": True})

         return model

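Taken together with the test added below, the intended user-facing round trip looks roughly like the following sketch ("path/to/llama-hf" and "path/to/llama-int4" are placeholder paths):

from bigdl.llm.transformers import AutoModelForCausalLM

# First load: quantize the original Hugging Face checkpoint to int4 on the fly.
model = AutoModelForCausalLM.from_pretrained("path/to/llama-hf", load_in_4bit=True)

# save_pretrained writes the int4 weights and, via model.config.update above,
# records "bigdl_transformers_int4": true in config.json.
model.save_pretrained("path/to/llama-int4")

# Second load: config.json flags the checkpoint as int4, so from_pretrained rebuilds the
# int4 layout with convert_shape_only=True and copies the saved quantized weights back in.
model = AutoModelForCausalLM.from_pretrained("path/to/llama-int4")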
python/llm/src/bigdl/llm/transformers/utils.py (new file, 94 lines)
@@ -0,0 +1,94 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py
# which is licensed under the MIT license:
#
# MIT License
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
from transformers.modeling_utils import _add_variant
from ..utils.common import invalidInputError
from typing import Union
import torch
from torch import nn


WEIGHTS_NAME = "pytorch_model.bin"


def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant):
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    print(os.path.join(pretrained_model_name_or_path,
                       subfolder,
                       _add_variant(WEIGHTS_NAME, variant)))
    if os.path.isfile(
        os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
    ):
        # Load from a PyTorch checkpoint
        archive_file = os.path.join(
            pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
        )
        return archive_file
    else:
        invalidInputError(False,
                          f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}"
                          " found in directory"
                          f" {pretrained_model_name_or_path}.")


def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
    try:
        return torch.load(checkpoint_file, map_location="cpu")
    except Exception as e:
        invalidInputError(False,
                          f"Unable to load weights"
                          "from pytorch checkpoint file for '{checkpoint_file}' "
                          f"at '{checkpoint_file}'. ")


# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, state_dict, prefix=""):
    args = (state_dict, prefix, {}, True, [], [], [])
    # Parameters of module and children will start with prefix.
    # We can exit early if there are none in this state_dict
    if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            module._load_from_state_dict(*args)

    for name, child in module._modules.items():
        if child is not None:
            load(child, state_dict, prefix + name + ".")
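A tiny demonstration (not part of the commit) of why the recursive load helper above is needed, using a throwaway torch module; _load_from_state_dict only fills the parameters owned directly by the module it is called on, never those of its children:

import torch
from torch import nn

net = nn.Sequential(nn.Linear(2, 2))
sd = {"0.weight": torch.zeros(2, 2), "0.bias": torch.zeros(2)}

# Calling _load_from_state_dict on the container does nothing: the keys belong to net[0].
net._load_from_state_dict(sd, "", {}, True, [], [], [])

# The load() helper defined above recurses with a growing prefix ("" -> "0.") and so
# reaches net[0], whose parameters are then copied from the state_dict.
load(net, sd)
assert torch.equal(net[0].weight.data, torch.zeros(2, 2))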
@@ -17,9 +17,11 @@

 import pytest
 import os
+import tempfile
 from unittest import TestCase

 from bigdl.llm import llm_convert
+from bigdl.llm.transformers import AutoModelForCausalLM


 llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH')
@@ -62,6 +64,14 @@ class TestConvertModel(TestCase):
                                            outtype='int4')
         assert os.path.isfile(converted_model_path)

+    def test_transformer_convert_llama(self):
+        model = AutoModelForCausalLM.from_pretrained(llama_model_path,
+                                                     load_in_4bit=True)
+        tempdir = tempfile.mkdtemp(dir=output_dir)
+        model.save_pretrained(tempdir)
+        model = AutoModelForCausalLM.from_pretrained(tempdir)
+        assert model is not None
+

 if __name__ == '__main__':
     pytest.main([__file__])