Add load_low_bit and save_low_bit to AutoModelForCausalLM (#8531)

* transformers save_low_bit load_low_bit

* update example and add readme

* update

* update

* update

* add ut

* update
Xin Qiu 2023-07-17 15:29:55 +08:00 committed by GitHub
parent 808a64d53a
commit fccae91461
4 changed files with 202 additions and 58 deletions


@@ -0,0 +1,43 @@
# BigDL-LLM Transformers Low-Bit Inference Pipeline for Large Language Models
In this example, we show a pipeline that applies BigDL-LLM low-bit optimizations to any Hugging Face Transformers model, runs inference on the optimized low-bit model, and saves/loads the low-bit model for later runs.
## Prepare Environment
We suggest using conda to manage the environment:
```bash
conda create -n llm python=3.9
conda activate llm
pip install --pre --upgrade bigdl-llm[all]
```
## Run Example
```bash
python ./transformers_low_bit_pipeline.py --model-path decapoda-research/llama-7b-hf --low-bit sym_int5 --save-path ./llama-7b-sym_int5
```
Arguments info:
- `--model-path`: str value, the Hugging Face repo id of the large language model to be downloaded, or the path to a local Hugging Face checkpoint folder. The default is `decapoda-research/llama-7b-hf`.
- `--low-bit`: str value, one of sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8 (sym_int4 means symmetric int 4, asym_int4 means asymmetric int 4, etc.). The corresponding low-bit optimizations will be applied to the model. The default is sym_int4.
- `--save-path`: optional str value, the path where the converted low-bit model is saved. A saved model can later be loaded directly via `--load-path`.
- `--load-path`: optional str value, the path of a previously saved low-bit model to load; when it is set, `--model-path` and `--low-bit` are ignored. See the sketch after this list for the underlying API calls.
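Under the hood, `--save-path` and `--load-path` simply call the `save_low_bit` / `load_low_bit` methods that BigDL-LLM adds to `AutoModelForCausalLM`. A minimal sketch of that flow, reusing the model id and save path from the example above:
```python
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer

# First run: download the checkpoint and convert it to the requested low-bit format.
model = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf",
                                             load_in_low_bit="sym_int5")
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

# Persist the already-converted weights next to the tokenizer files.
model.save_low_bit("./llama-7b-sym_int5")
tokenizer.save_pretrained("./llama-7b-sym_int5")

# Later runs: load the low-bit checkpoint directly, skipping the conversion step.
model = AutoModelForCausalLM.load_low_bit("./llama-7b-sym_int5")
tokenizer = LlamaTokenizer.from_pretrained("./llama-7b-sym_int5")
```
Loading the saved folder with `load_low_bit` skips the on-the-fly conversion that `from_pretrained(..., load_in_low_bit=...)` performs, so subsequent runs start from the already-quantized checkpoint.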
## Sample Output for Inference
### 'decapoda-research/llama-7b-hf' Model
```log
Prompt: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
Output: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. She wanted to be a princess, and she wanted to be a pirate. She wanted to be a superhero, and she wanted to be
Model and tokenizer are saved to ./llama-7b-sym_int5
```
### Load low-bit model
Command to run:
```bash
python ./transformers_low_bit_pipeline.py --load-path ./llama-7b-sym_int5
```
Output log:
```log
Prompt: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
Output: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. She wanted to be a princess, and she wanted to be a pirate. She wanted to be a superhero, and she wanted to be
```


@@ -0,0 +1,56 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse

from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer, TextGenerationPipeline

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Transformer save_load example')
    parser.add_argument('--model-path', type=str, default="decapoda-research/llama-7b-hf",
                        help='The huggingface repo id for the large language model to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--low-bit', type=str, default="sym_int4",
                        choices=['sym_int4', 'asym_int4', 'sym_int5', 'asym_int5', 'sym_int8'],
                        help='The quantization type the model will convert to.')
    parser.add_argument('--save-path', type=str, default=None,
                        help='The path to save the low-bit model.')
    parser.add_argument('--load-path', type=str, default=None,
                        help='The path to load the low-bit model.')
    args = parser.parse_args()
    model_path = args.model_path
    low_bit = args.low_bit
    load_path = args.load_path

    if load_path:
        model = AutoModelForCausalLM.load_low_bit(load_path)
        tokenizer = LlamaTokenizer.from_pretrained(load_path)
    else:
        # load_in_low_bit in bigdl.llm.transformers will convert
        # the relevant layers in the model into corresponding int X format
        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit)
        tokenizer = LlamaTokenizer.from_pretrained(model_path)

    pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer, max_new_tokens=32)
    input_str = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
    output = pipeline(input_str)[0]["generated_text"]
    print(f"Prompt: {input_str}")
    print(f"Output: {output}")

    save_path = args.save_path
    if save_path:
        model.save_low_bit(save_path)
        tokenizer.save_pretrained(save_path)
        print(f"Model and tokenizer are saved to {save_path}")


@@ -21,6 +21,13 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 
 
+def save_low_bit(self, *args, **kwargs):
+    invalidInputError(self.config.to_dict().get("bigdl_transformers_low_bit", False),
+                      f"Detected this model is not a low-bit model, please use from_pretrained's"
+                      f" load_in_4bit or load_in_low_bit parameter to load a 4-bit model first.")
+    self.save_pretrained(*args, **kwargs)
+
+
 class _BaseAutoModelClass:
     HF_Model = None
@@ -36,76 +43,38 @@ class _BaseAutoModelClass:
         Two new arguments are added to extend Hugging Face's from_pretrained method as follows:
         New Arguments:
             load_in_4bit: boolean value, True means load linear's weight to symmetric int 4.
-            load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5 or
-                             sym_int8. The model's linear will be loaded into corresponding
-                             low-bit type. sym_int4 means symmetric int 4, asym_int4 means
-                             asymmetric int 4.
+            load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5
+                             or sym_int8. (sym_int4 means symmetric int 4, asym_int4 means
+                             asymmetric int 4, etc.). Relevant low bit optimizations will
+                             be applied to the model.
         """
-        # For huggingface transformers cls.HF_Model.from_pretrained could only restore the model
-        # in the original format, which is not quantized,
-        # we can convert the model to quantized later.
-        model = None
-        load_in_4bit = kwargs.pop("load_in_4bit", False)
-        load_in_low_bit = kwargs.pop("load_in_low_bit", None)
         # Read bigdl_transformers_low_bit from config.json
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
             if len(args) == 0 else args[0]
         config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
         bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
+        invalidInputError(not bigdl_transformers_low_bit,
+                          f"Detected model is a low-bit({bigdl_transformers_low_bit}) model, "
+                          f"Please use load_low_bit to load this model.")
 
-        if load_in_4bit or load_in_low_bit or bigdl_transformers_low_bit:
-            # Speed up when loading model
+        # For huggingface transformers cls.HF_Model.from_pretrained could only restore the model
+        # in the original format, which is not quantized,
+        # we can convert the model to quantized later.
+        load_in_4bit = kwargs.pop("load_in_4bit", False)
+        load_in_low_bit = kwargs.pop("load_in_low_bit", None)
+
+        if load_in_4bit or load_in_low_bit:
+            # load int x-bit
             kwargs["low_cpu_mem_usage"] = True
-
-        if bigdl_transformers_low_bit:
-            invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype,
-                              f"Unknown bigdl_transformers_low_bit value:"
-                              f" {bigdl_transformers_low_bit},"
-                              f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
-            qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
-            # Note that the int4 linear layers cannot currently
-            # be recorded in huggingface Pretrained Model or AutoConfig,
-            # and huggingface transformers cls.HF_Model.from_pretrained
-            # could only restore the model in the original format,
-            # which is not quantized. we can Initialize original model first,
-            # convert the model to quantized int4 format later, and then load the quantized model.
-            # Avoid KeyError
-            kwargs["ignore_mismatched_sizes"] = True
-            # Avoid reading from local file at the first initialization
-            kwargs["state_dict"] = {}
-            # Maybe needed when extract_local_archive_file
-            subfolder = kwargs.get("subfolder", "")
-            variant = kwargs.get("variant", None)
-            from .convert import ggml_convert_quant
-            model = cls.HF_Model.from_pretrained(*args, **kwargs)
-            print("Note: If there are warnings during the model loading process, "
-                  "they can be safely ignored; "
-                  "the model will be loaded with INT4 optimizations applied.")
-            # We forcefully modify the model's definition
-            # and the tensor shape of int4 weights without quantization.
-            model = ggml_convert_quant(model, qtype, convert_shape_only=True)
-            # Load the quantized model at last.
-            archive_file = extract_local_archive_file(pretrained_model_name_or_path,
-                                                      subfolder,
-                                                      variant)
-            state_dict = load_state_dict(archive_file)
-            load(model, state_dict)
-            del state_dict
-        elif load_in_4bit or load_in_low_bit:
             q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
-            model = cls.convert_quant(model, q_k, *args, **kwargs)
+            model = cls.load_convert(q_k, *args, **kwargs)
         else:
             # load default
             model = cls.HF_Model.from_pretrained(*args, **kwargs)
 
         return model
 
     @classmethod
-    def convert_quant(cls, model, q_k, *args, **kwargs):
+    def load_convert(cls, q_k, *args, **kwargs):
         from .convert import ggml_convert_quant
         invalidInputError(q_k in ggml_tensor_qtype,
                           f"Unknown load_in_low_bit value: {q_k}, expected:"
@@ -115,6 +84,73 @@ class _BaseAutoModelClass:
         model = model.to("cpu")
         model = ggml_convert_quant(model, qtype)
         model.config.update({"bigdl_transformers_low_bit": q_k})
+        # add save_low_bit to pretrained model dynamically
+        import types
+        model.save_low_bit = types.MethodType(save_low_bit, model)
 
         return model
 
+    @classmethod
+    def load_low_bit(cls,
+                     *args,
+                     **kwargs):
+        # Read bigdl_transformers_low_bit from config.json
+        pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
+            if len(args) == 0 else args[0]
+        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
+        bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
+        invalidInputError(bigdl_transformers_low_bit,
+                          "Detected this model is not a low-bit model, please use"
+                          " from_pretrained with load_in_4bit or load_in_low_bit to get"
+                          " a low-bit model, and serialize the model using save_low_bit first.")
+        invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype,
+                          f"Unknown bigdl_transformers_low_bit value: {bigdl_transformers_low_bit},"
+                          f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
+
+        # Speed up when loading model
+        kwargs["low_cpu_mem_usage"] = True
+
+        qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
+        # Note that the int4 linear layers cannot currently
+        # be recorded in huggingface Pretrained Model or AutoConfig,
+        # and huggingface transformers cls.HF_Model.from_pretrained
+        # could only restore the model in the original format,
+        # which is not quantized. we can Initialize original model first,
+        # convert the model to quantized int4 format later, and then load the quantized model.
+
+        # Avoid KeyError
+        kwargs["ignore_mismatched_sizes"] = True
+        # Avoid reading from local file at the first initialization
+        kwargs["state_dict"] = {}
+
+        # Maybe needed when extract_local_archive_file
+        subfolder = kwargs.get("subfolder", "")
+        variant = kwargs.get("variant", None)
+
+        from .convert import ggml_convert_quant
+        model = cls.HF_Model.from_pretrained(*args, **kwargs)
+        print("Note: If there are warnings during the model loading process, "
+              "they can be safely ignored; "
+              "the model will be loaded with INT4 optimizations applied.")
+
+        # add save_low_bit to pretrained model dynamically
+        import types
+        model.save_low_bit = types.MethodType(save_low_bit, model)
+
+        # We forcefully modify the model's definition
+        # and the tensor shape of int4 weights without quantization.
+        model = ggml_convert_quant(model, qtype, convert_shape_only=True)
+        # Load the quantized model at last.
+        archive_file = extract_local_archive_file(pretrained_model_name_or_path,
+                                                  subfolder,
+                                                  variant)
+        state_dict = load_state_dict(archive_file)
+        load(model, state_dict)
+        del state_dict
+
+        return model
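Both `load_convert` and `load_low_bit` attach `save_low_bit` to the returned model at runtime with `types.MethodType`, because the upstream Hugging Face model classes do not define such a method. The snippet below is a small self-contained illustration of that binding pattern only; the `Model` class and the path in it are made up for this sketch and are not part of BigDL-LLM:

```python
import types


class Model:
    """Stand-in for a Hugging Face model object (illustration only)."""

    def __init__(self):
        self.config = {"bigdl_transformers_low_bit": "sym_int4"}


def save_low_bit(self, path):
    # 'self' is the instance the function gets bound to below.
    print(f"would save a model with config {self.config} to {path}")


model = Model()
# Bind the free function to this single instance; other Model instances
# are unaffected and no subclassing of Model is required.
model.save_low_bit = types.MethodType(save_low_bit, model)
model.save_low_bit("./llama-7b-sym_int4")
```

Binding per instance keeps `save_low_bit` off the shared Hugging Face classes while still letting callers write `model.save_low_bit(path)` on any model produced by the BigDL loaders.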


@@ -79,6 +79,15 @@ class TestConvertModel(TestCase):
         model = AutoModelForCausalLM.from_pretrained(llama_model_path,
                                                      load_in_low_bit="sym_int8")
 
+    def test_transformer_convert_llama_save_load(self):
+        model = AutoModelForCausalLM.from_pretrained(llama_model_path,
+                                                     load_in_low_bit="asym_int4")
+        tempdir = tempfile.mkdtemp(dir=output_dir)
+        model.save_low_bit(tempdir)
+        newModel = AutoModelForCausalLM.load_low_bit(tempdir)
+        import shutil
+        shutil.rmtree(tempdir)
+
 
 if __name__ == '__main__':
     pytest.main([__file__])