From fccae914615d2c73448ef53e761cf16e165c0e78 Mon Sep 17 00:00:00 2001
From: Xin Qiu
Date: Mon, 17 Jul 2023 15:29:55 +0800
Subject: [PATCH] Add load_low_bit and save_low_bit to AutoModelForCausalLM (#8531)

* transformers save_low_bit load_low_bit
* update example and add readme
* update
* update
* update
* add ut
* update
---
 .../transformers_low_bit/README.md            |  43 +++++
 .../transformers_low_bit_pipeline.py          |  56 +++++++
 .../llm/src/bigdl/llm/transformers/model.py   | 152 +++++++++++-------
 python/llm/test/convert/test_convert_model.py |   9 ++
 4 files changed, 202 insertions(+), 58 deletions(-)
 create mode 100644 python/llm/example/transformers/transformers_low_bit/README.md
 create mode 100644 python/llm/example/transformers/transformers_low_bit/transformers_low_bit_pipeline.py

diff --git a/python/llm/example/transformers/transformers_low_bit/README.md b/python/llm/example/transformers/transformers_low_bit/README.md
new file mode 100644
index 00000000..4bf23fae
--- /dev/null
+++ b/python/llm/example/transformers/transformers_low_bit/README.md
@@ -0,0 +1,43 @@
+# BigDL-LLM Transformers Low-Bit Inference Pipeline for Large Language Models
+
+In this example, we show a pipeline to apply BigDL-LLM low-bit optimizations to any Hugging Face Transformers model, and then run inference on the optimized low-bit model.
+
+## Prepare Environment
+We suggest using conda to manage the environment:
+```bash
+conda create -n llm python=3.9
+conda activate llm
+
+pip install --pre --upgrade bigdl-llm[all]
+```
+
+## Run Example
+```bash
+python ./transformers_low_bit_pipeline.py --model-path decapoda-research/llama-7b-hf --low-bit sym_int5 --save-path ./llama-7b-sym_int5
+```
+Arguments info:
+- `--model-path`: str value; the Hugging Face repo id of the large language model to be downloaded, or the path to a Hugging Face checkpoint folder. It is 'decapoda-research/llama-7b-hf' by default.
+- `--low-bit`: str value; options are sym_int4, asym_int4, sym_int5, asym_int5 and sym_int8 (sym_int4 means symmetric int 4, asym_int4 means asymmetric int 4, etc.). The corresponding low-bit optimizations will be applied to the model.
+- `--save-path`: optional str value; the path to save the low-bit model. You can then load the low-bit model directly from this path.
+- `--load-path`: optional str value; the path to load a previously saved low-bit model from.
+
+## Sample Output for Inference
+### 'decapoda-research/llama-7b-hf' Model
+```log
+Prompt: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
+Output: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. She wanted to be a princess, and she wanted to be a pirate. She wanted to be a superhero, and she wanted to be
+Model and tokenizer are saved to ./llama-7b-sym_int5
+```
+
+### Load low-bit model
+Command to run:
+```bash
+python ./transformers_low_bit_pipeline.py --load-path ./llama-7b-sym_int5
+```
+Output log:
+```log
+Prompt: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
+Output: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. She wanted to be a princess, and she wanted to be a pirate. She wanted to be a superhero, and she wanted to be
+```
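+
+### Save and load in Python
+The same save/load flow can also be scripted directly against the API. Below is a minimal sketch (the model path and save path are placeholders):
+```python
+from bigdl.llm.transformers import AutoModelForCausalLM
+from transformers import LlamaTokenizer
+
+# First load: quantize the original checkpoint to the requested low-bit format
+model = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf",
+                                             load_in_low_bit="sym_int5")
+tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+
+# Serialize the low-bit model and the tokenizer
+model.save_low_bit("./llama-7b-sym_int5")
+tokenizer.save_pretrained("./llama-7b-sym_int5")
+
+# Later: restore the low-bit model without quantizing again
+model = AutoModelForCausalLM.load_low_bit("./llama-7b-sym_int5")
+tokenizer = LlamaTokenizer.from_pretrained("./llama-7b-sym_int5")
+```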
diff --git a/python/llm/example/transformers/transformers_low_bit/transformers_low_bit_pipeline.py b/python/llm/example/transformers/transformers_low_bit/transformers_low_bit_pipeline.py
new file mode 100644
index 00000000..d2ae7cab
--- /dev/null
+++ b/python/llm/example/transformers/transformers_low_bit/transformers_low_bit_pipeline.py
@@ -0,0 +1,56 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+from bigdl.llm.transformers import AutoModelForCausalLM
+from transformers import LlamaTokenizer, TextGenerationPipeline
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Transformers low-bit save/load example')
+    parser.add_argument('--model-path', type=str, default="decapoda-research/llama-7b-hf",
+                        help='The huggingface repo id for the large language model to be downloaded'
+                             ', or the path to the huggingface checkpoint folder')
+    parser.add_argument('--low-bit', type=str, default="sym_int4",
+                        choices=['sym_int4', 'asym_int4', 'sym_int5', 'asym_int5', 'sym_int8'],
+                        help='The quantization type to convert the model to.')
+    parser.add_argument('--save-path', type=str, default=None,
+                        help='The path to save the low-bit model.')
+    parser.add_argument('--load-path', type=str, default=None,
+                        help='The path to load the low-bit model.')
+    args = parser.parse_args()
+    model_path = args.model_path
+    low_bit = args.low_bit
+    load_path = args.load_path
+    if load_path:
+        model = AutoModelForCausalLM.load_low_bit(load_path)
+        tokenizer = LlamaTokenizer.from_pretrained(load_path)
+    else:
+        # load_in_low_bit in bigdl.llm.transformers will convert
+        # the relevant layers in the model into the corresponding low-bit format
+        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit)
+        tokenizer = LlamaTokenizer.from_pretrained(model_path)
+
+    pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer, max_new_tokens=32)
+    input_str = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
+    output = pipeline(input_str)[0]["generated_text"]
+    print(f"Prompt: {input_str}")
+    print(f"Output: {output}")
+
+    save_path = args.save_path
+    if save_path:
+        model.save_low_bit(save_path)
+        tokenizer.save_pretrained(save_path)
+        print(f"Model and tokenizer are saved to {save_path}")
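+
+# Example invocations (paths are placeholders; see the README for details):
+#   python ./transformers_low_bit_pipeline.py --low-bit sym_int5 --save-path ./llama-7b-sym_int5
+#   python ./transformers_low_bit_pipeline.py --load-path ./llama-7b-sym_int5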
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index c4293de3..85e1015a 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -21,6 +21,13 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 
 
+def save_low_bit(self, *args, **kwargs):
+    invalidInputError(self.config.to_dict().get("bigdl_transformers_low_bit", False),
+                      "Detected that this model is not a low-bit model; please use"
+                      " from_pretrained's load_in_4bit or load_in_low_bit parameter"
+                      " to load a low-bit model first.")
+    self.save_pretrained(*args, **kwargs)
+
+
 class _BaseAutoModelClass:
 
     HF_MODEL = None
 
@@ -36,76 +43,38 @@ class _BaseAutoModelClass:
         Two new arguments are added to extend Hugging Face's from_pretrained method as follows:
         New Arguments:
         load_in_4bit: boolean value, True means load linear's weight to symmetric int 4.
-        load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5 or
-                         sym_int8. The model's linear will be loaded into corresponding
-                         low-bit type. sym_int4 means symmetric int 4, asym_int4 means
-                         asymmetric int 4.
+        load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5
+                         or sym_int8 (sym_int4 means symmetric int 4, asym_int4 means
+                         asymmetric int 4, etc.). The corresponding low-bit optimizations
+                         will be applied to the model.
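+
+        A minimal usage sketch (the model path is a placeholder):
+            model = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf",
+                                                         load_in_low_bit="sym_int4")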
         """
-
-        # For huggingface transformers cls.HF_Model.from_pretrained could only restore the model
-        # in the original format, which is not quantized,
-        # we can convert the model to quantized later.
-        model = None
-        load_in_4bit = kwargs.pop("load_in_4bit", False)
-        load_in_low_bit = kwargs.pop("load_in_low_bit", None)
-
-        # Read bigdl_transformers_low_bit from config.json
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
             if len(args) == 0 else args[0]
         config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
         bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
+        invalidInputError(not bigdl_transformers_low_bit,
+                          f"Detected that the model is a low-bit ({bigdl_transformers_low_bit})"
+                          " model; please use load_low_bit to load it.")
 
-        if load_in_4bit or load_in_low_bit or bigdl_transformers_low_bit:
-            # Speed up when loading model
+        # For huggingface transformers, cls.HF_Model.from_pretrained can only restore
+        # the model in its original (unquantized) format;
+        # we convert the model to the quantized format afterwards.
+
+        load_in_4bit = kwargs.pop("load_in_4bit", False)
+        load_in_low_bit = kwargs.pop("load_in_low_bit", None)
+
+        if load_in_4bit or load_in_low_bit:
+            # load the model with low-bit optimizations applied
             kwargs["low_cpu_mem_usage"] = True
-
-        if bigdl_transformers_low_bit:
-            invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype,
-                              f"Unknown bigdl_transformers_low_bit value:"
-                              f" {bigdl_transformers_low_bit},"
-                              f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
-            qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
-            # Note that the int4 linear layers cannot currently
-            # be recorded in huggingface Pretrained Model or AutoConfig,
-            # and huggingface transformers cls.HF_Model.from_pretrained
-            # could only restore the model in the original format,
-            # which is not quantized. we can Initialize original model first,
-            # convert the model to quantized int4 format later, and then load the quantized model.
-
-            # Avoid KeyError
-            kwargs["ignore_mismatched_sizes"] = True
-            # Avoid reading from local file at the first initialization
-            kwargs["state_dict"] = {}
-
-            # Maybe needed when extract_local_archive_file
-            subfolder = kwargs.get("subfolder", "")
-            variant = kwargs.get("variant", None)
-
-            from .convert import ggml_convert_quant
-            model = cls.HF_Model.from_pretrained(*args, **kwargs)
-            print("Note: If there are warnings during the model loading process, "
-                  "they can be safely ignored; "
-                  "the model will be loaded with INT4 optimizations applied.")
-
-            # We forcefully modify the model's definition
-            # and the tensor shape of int4 weights without quantization.
-            model = ggml_convert_quant(model, qtype, convert_shape_only=True)
-            # Load the quantized model at last.
-            archive_file = extract_local_archive_file(pretrained_model_name_or_path,
-                                                      subfolder,
-                                                      variant)
-            state_dict = load_state_dict(archive_file)
-            load(model, state_dict)
-            del state_dict
-
-        elif load_in_4bit or load_in_low_bit:
             q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
-            model = cls.convert_quant(model, q_k, *args, **kwargs)
+            model = cls.load_convert(q_k, *args, **kwargs)
+        else:
+            # load the model in the default (unquantized) format
+            model = cls.HF_Model.from_pretrained(*args, **kwargs)
         return model
 
     @classmethod
-    def convert_quant(cls, model, q_k, *args, **kwargs):
+    def load_convert(cls, q_k, *args, **kwargs):
         from .convert import ggml_convert_quant
         invalidInputError(q_k in ggml_tensor_qtype,
                           f"Unknown load_in_low_bit value: {q_k}, expected:"
@@ -115,6 +84,73 @@ class _BaseAutoModelClass:
         model = model.to("cpu")
         model = ggml_convert_quant(model, qtype)
         model.config.update({"bigdl_transformers_low_bit": q_k})
+
+        # add save_low_bit to the pretrained model dynamically
+        import types
+        model.save_low_bit = types.MethodType(save_low_bit, model)
+
         return model
+
+    @classmethod
+    def load_low_bit(cls,
+                     *args,
+                     **kwargs):
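+        """
+        Load a model that was saved with save_low_bit; the low-bit weights are
+        restored directly, without quantizing again. A minimal usage sketch
+        (the path is a placeholder):
+
+            model = AutoModelForCausalLM.load_low_bit("./llama-7b-sym_int5")
+        """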
+        # Read bigdl_transformers_low_bit from config.json
+        pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
+            if len(args) == 0 else args[0]
+        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
+        bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
+
+        invalidInputError(bigdl_transformers_low_bit,
+                          "Detected that this model is not a low-bit model. Please use"
+                          " from_pretrained with load_in_4bit or load_in_low_bit to get a"
+                          " low-bit model, and serialize the model using save_low_bit first.")
+
+        invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype,
+                          f"Unknown bigdl_transformers_low_bit value: {bigdl_transformers_low_bit},"
+                          f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
+
+        # Speed up when loading the model
+        kwargs["low_cpu_mem_usage"] = True
+
+        qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
+        # Note that the low-bit linear layers cannot currently be recorded in a
+        # huggingface PreTrainedModel or AutoConfig, and huggingface transformers
+        # cls.HF_Model.from_pretrained can only restore the model in its original,
+        # unquantized format. We initialize the original model first, convert it
+        # to the quantized low-bit format, and then load the quantized weights.
+
+        # Avoid KeyError
+        kwargs["ignore_mismatched_sizes"] = True
+        # Avoid reading from a local file at the first initialization
+        kwargs["state_dict"] = {}
+
+        # May be needed by extract_local_archive_file
+        subfolder = kwargs.get("subfolder", "")
+        variant = kwargs.get("variant", None)
+
+        from .convert import ggml_convert_quant
+        model = cls.HF_Model.from_pretrained(*args, **kwargs)
+        print("Note: If there are warnings during the model loading process, "
+              "they can be safely ignored; "
+              "the model will be loaded with low-bit optimizations applied.")
+
+        # add save_low_bit to the pretrained model dynamically
+        import types
+        model.save_low_bit = types.MethodType(save_low_bit, model)
+
+        # We forcefully modify the model's definition and the tensor shapes of
+        # the low-bit weights, without running quantization.
+        model = ggml_convert_quant(model, qtype, convert_shape_only=True)
+        # Finally, load the quantized weights.
+        archive_file = extract_local_archive_file(pretrained_model_name_or_path,
+                                                  subfolder,
+                                                  variant)
+        state_dict = load_state_dict(archive_file)
+        load(model, state_dict)
+        del state_dict
+        return model
diff --git a/python/llm/test/convert/test_convert_model.py b/python/llm/test/convert/test_convert_model.py
index d2a7d794..aeb4fa80 100644
--- a/python/llm/test/convert/test_convert_model.py
+++ b/python/llm/test/convert/test_convert_model.py
@@ -79,6 +79,15 @@ class TestConvertModel(TestCase):
         model = AutoModelForCausalLM.from_pretrained(llama_model_path,
                                                      load_in_low_bit="sym_int8")
 
+    def test_transformer_convert_llama_save_load(self):
+        import shutil
+        model = AutoModelForCausalLM.from_pretrained(llama_model_path,
+                                                     load_in_low_bit="asym_int4")
+        tempdir = tempfile.mkdtemp(dir=output_dir)
+        model.save_low_bit(tempdir)
+        new_model = AutoModelForCausalLM.load_low_bit(tempdir)
+        shutil.rmtree(tempdir)
+
 
 if __name__ == '__main__':
     pytest.main([__file__])