Add save_low_bit and load_low_bit to AutoModelForCausalLM (#8531)
* transformers save_low_bit load_low_bit
* update example and add readme
* update
* update
* update
* add ut
* update
parent 808a64d53a
commit fccae91461
4 changed files with 202 additions and 58 deletions
@@ -0,0 +1,43 @@
# BigDL-LLM Transformers Low-Bit Inference Pipeline for Large Language Models
In this example, we show a pipeline to apply BigDL-LLM low-bit optimizations to any Hugging Face Transformers model, and then run inference on the optimized low-bit model.
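At its core, the pipeline boils down to a few calls on `bigdl.llm.transformers.AutoModelForCausalLM`; the following minimal sketch shows the flow (the model id and save path are simply the ones used in the run example below):

```python
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer

# Load a Hugging Face checkpoint and apply a low-bit optimization (e.g. sym_int5)
model = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf",
                                             load_in_low_bit="sym_int5")
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

# Save the optimized model, then reload it later without converting again
model.save_low_bit("./llama-7b-sym_int5")
tokenizer.save_pretrained("./llama-7b-sym_int5")
new_model = AutoModelForCausalLM.load_low_bit("./llama-7b-sym_int5")
```

Loading from the saved folder restores the already-converted weights, so it skips downloading and converting the original checkpoint.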
## Prepare Environment

We suggest using conda to manage the environment:

```bash
conda create -n llm python=3.9
conda activate llm

pip install --pre --upgrade bigdl-llm[all]
```
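Optionally, you can verify the installation from the activated environment before running the example (a quick sanity check, not required):

```python
# Quick check that the bigdl-llm transformers API is importable
from bigdl.llm.transformers import AutoModelForCausalLM  # noqa: F401
print("bigdl-llm is installed correctly")
```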
## Run Example

```bash
python ./transformers_low_bit_pipeline.py --model-path decapoda-research/llama-7b-hf --low-bit sym_int5 --save-path ./llama-7b-sym_int5
```

Arguments info:

- `--model-path`: str value, the Hugging Face repo id of the large language model to be downloaded, or the path to a local Hugging Face checkpoint folder; the default is 'decapoda-research/llama-7b-hf'.
- `--low-bit`: str value, options are sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8 (sym_int4 means symmetric int 4, asym_int4 means asymmetric int 4, etc.). The corresponding low-bit optimizations will be applied to the model.
- `--save-path`: str value, the path to save the low-bit model. You can then load the low-bit model from this path directly.
- `--load-path`: optional str value, the path to load a previously saved low-bit model; see the note below on how a low-bit checkpoint is recognized.
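Note: a folder produced by `--save-path` is a regular Hugging Face checkpoint whose `config.json` additionally records the low-bit format under the `bigdl_transformers_low_bit` key; `load_low_bit` reads this field to decide how to restore the model. A small sketch to inspect a folder (assuming the save path used above):

```python
from transformers import PretrainedConfig

# A checkpoint saved by save_low_bit carries this marker in its config.json
config_dict, _ = PretrainedConfig.get_config_dict("./llama-7b-sym_int5")
print(config_dict.get("bigdl_transformers_low_bit"))  # e.g. 'sym_int5'
```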
## Sample Output for Inference

### 'decapoda-research/llama-7b-hf' Model

```log
Prompt: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
Output: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. She wanted to be a princess, and she wanted to be a pirate. She wanted to be a superhero, and she wanted to be
Model and tokenizer are saved to ./llama-7b-sym_int5
```
### Load Low-Bit Model

Command to run:

```bash
python ./transformers_low_bit_pipeline.py --load-path ./llama-7b-sym_int5
```

Output log:

```log
Prompt: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
Output: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. She wanted to be a princess, and she wanted to be a pirate. She wanted to be a superhero, and she wanted to be
```
@@ -0,0 +1,56 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer, TextGenerationPipeline

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Transformer save_load example')
    parser.add_argument('--model-path', type=str, default="decapoda-research/llama-7b-hf",
                        help='The huggingface repo id for the large language model to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--low-bit', type=str, default="sym_int4",
                        choices=['sym_int4', 'asym_int4', 'sym_int5', 'asym_int5', 'sym_int8'],
                        help='The quantization type the model will convert to.')
    parser.add_argument('--save-path', type=str, default=None,
                        help='The path to save the low-bit model.')
    parser.add_argument('--load-path', type=str, default=None,
                        help='The path to load the low-bit model.')
    args = parser.parse_args()
    model_path = args.model_path
    low_bit = args.low_bit
    load_path = args.load_path
    if load_path:
        model = AutoModelForCausalLM.load_low_bit(load_path)
        tokenizer = LlamaTokenizer.from_pretrained(load_path)
    else:
        # load_in_low_bit in bigdl.llm.transformers will convert
        # the relevant layers in the model into corresponding int X format
        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit)
        tokenizer = LlamaTokenizer.from_pretrained(model_path)

    pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer, max_new_tokens=32)
    input_str = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
    output = pipeline(input_str)[0]["generated_text"]
    print(f"Prompt: {input_str}")
    print(f"Output: {output}")

    save_path = args.save_path
    if save_path:
        model.save_low_bit(save_path)
        tokenizer.save_pretrained(save_path)
        print(f"Model and tokenizer are saved to {save_path}")
@@ -21,6 +21,13 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError


def save_low_bit(self, *args, **kwargs):
    invalidInputError(self.config.to_dict().get("bigdl_transformers_low_bit", False),
                      f"Detected this model is not a low-bit model, please use from_pretrained's"
                      f" load_in_4bit or load_in_low_bit parameter to load a low-bit model first.")
    self.save_pretrained(*args, **kwargs)


class _BaseAutoModelClass:

    HF_Model = None
@@ -36,76 +43,38 @@ class _BaseAutoModelClass:
Two new arguments are added to extend Hugging Face's from_pretrained method as follows:
New Arguments:
load_in_4bit: boolean value, True means load linear's weight to symmetric int 4.
load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5 or
sym_int8. The model's linear will be loaded into corresponding
low-bit type. sym_int4 means symmetric int 4, asym_int4 means
asymmetric int 4.
load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5
or sym_int8. (sym_int4 means symmetric int 4, asym_int4 means
asymmetric int 4, etc.). Relevant low bit optimizations will
be applied to the model.
"""

# For huggingface transformers cls.HF_Model.from_pretrained could only restore the model
# in the original format, which is not quantized,
# we can convert the model to quantized later.
model = None
load_in_4bit = kwargs.pop("load_in_4bit", False)
load_in_low_bit = kwargs.pop("load_in_low_bit", None)

# Read bigdl_transformers_low_bit from config.json
pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
    if len(args) == 0 else args[0]
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
invalidInputError(not bigdl_transformers_low_bit,
                  f"Detected model is a low-bit({bigdl_transformers_low_bit}) model, "
                  f"Please use load_low_bit to load this model.")

if load_in_4bit or load_in_low_bit or bigdl_transformers_low_bit:
# Speed up when loading model
# For huggingface transformers cls.HF_Model.from_pretrained could only restore the model
# in the original format, which is not quantized,
# we can convert the model to quantized later.
load_in_4bit = kwargs.pop("load_in_4bit", False)
load_in_low_bit = kwargs.pop("load_in_low_bit", None)

if load_in_4bit or load_in_low_bit:
    # load int x-bit
    kwargs["low_cpu_mem_usage"] = True

if bigdl_transformers_low_bit:
invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype,
                  f"Unknown bigdl_transformers_low_bit value:"
                  f" {bigdl_transformers_low_bit},"
                  f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
# Note that the int4 linear layers cannot currently
# be recorded in huggingface Pretrained Model or AutoConfig,
# and huggingface transformers cls.HF_Model.from_pretrained
# could only restore the model in the original format,
# which is not quantized. we can Initialize original model first,
# convert the model to quantized int4 format later, and then load the quantized model.

# Avoid KeyError
kwargs["ignore_mismatched_sizes"] = True
# Avoid reading from local file at the first initialization
kwargs["state_dict"] = {}

# Maybe needed when extract_local_archive_file
subfolder = kwargs.get("subfolder", "")
variant = kwargs.get("variant", None)

from .convert import ggml_convert_quant
model = cls.HF_Model.from_pretrained(*args, **kwargs)
print("Note: If there are warnings during the model loading process, "
      "they can be safely ignored; "
      "the model will be loaded with INT4 optimizations applied.")

# We forcefully modify the model's definition
# and the tensor shape of int4 weights without quantization.
model = ggml_convert_quant(model, qtype, convert_shape_only=True)
# Load the quantized model at last.
archive_file = extract_local_archive_file(pretrained_model_name_or_path,
                                          subfolder,
                                          variant)
state_dict = load_state_dict(archive_file)
load(model, state_dict)
del state_dict

elif load_in_4bit or load_in_low_bit:
    q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
    model = cls.convert_quant(model, q_k, *args, **kwargs)
    model = cls.load_convert(q_k, *args, **kwargs)
else:
    # load default
    model = cls.HF_Model.from_pretrained(*args, **kwargs)

return model

@classmethod
def convert_quant(cls, model, q_k, *args, **kwargs):
def load_convert(cls, q_k, *args, **kwargs):
    from .convert import ggml_convert_quant
    invalidInputError(q_k in ggml_tensor_qtype,
                      f"Unknown load_in_low_bit value: {q_k}, expected:"
@@ -115,6 +84,73 @@ class _BaseAutoModelClass:
        model = model.to("cpu")
        model = ggml_convert_quant(model, qtype)
        model.config.update({"bigdl_transformers_low_bit": q_k})

        # add save_low_bit to pretrained model dynamically
        import types
        model.save_low_bit = types.MethodType(save_low_bit, model)

        return model

    @classmethod
    def load_low_bit(cls,
                     *args,
                     **kwargs):
        # Read bigdl_transformers_low_bit from config.json
        pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
            if len(args) == 0 else args[0]
        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
        bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)

        invalidInputError(bigdl_transformers_low_bit,
                          "Detected this model is not a low-bit model. Please use from_pretrained"
                          " with load_in_4bit or load_in_low_bit to get a low-bit model, and"
                          " serialize the model using save_low_bit first.")

        invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype,
                          f"Unknown bigdl_transformers_low_bit value: {bigdl_transformers_low_bit},"
                          f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")

        # Speed up when loading model
        kwargs["low_cpu_mem_usage"] = True

        qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
        # Note that the int4 linear layers cannot currently
        # be recorded in huggingface Pretrained Model or AutoConfig,
        # and huggingface transformers cls.HF_Model.from_pretrained
        # could only restore the model in the original format,
        # which is not quantized. We can initialize the original model first,
        # convert the model to the quantized int4 format later, and then load the quantized model.

        # Avoid KeyError
        kwargs["ignore_mismatched_sizes"] = True
        # Avoid reading from local file at the first initialization
        kwargs["state_dict"] = {}

        # Maybe needed when extract_local_archive_file
        subfolder = kwargs.get("subfolder", "")
        variant = kwargs.get("variant", None)

        from .convert import ggml_convert_quant
        model = cls.HF_Model.from_pretrained(*args, **kwargs)
        print("Note: If there are warnings during the model loading process, "
              "they can be safely ignored; "
              "the model will be loaded with INT4 optimizations applied.")

        # add save_low_bit to pretrained model dynamically
        import types
        model.save_low_bit = types.MethodType(save_low_bit, model)

        # We forcefully modify the model's definition
        # and the tensor shape of int4 weights without quantization.
        model = ggml_convert_quant(model, qtype, convert_shape_only=True)
        # Load the quantized model at last.
        archive_file = extract_local_archive_file(pretrained_model_name_or_path,
                                                  subfolder,
                                                  variant)
        state_dict = load_state_dict(archive_file)
        load(model, state_dict)
        del state_dict

        return model
@@ -79,6 +79,15 @@ class TestConvertModel(TestCase):
        model = AutoModelForCausalLM.from_pretrained(llama_model_path,
                                                     load_in_low_bit="sym_int8")

    def test_transformer_convert_llama_save_load(self):
        model = AutoModelForCausalLM.from_pretrained(llama_model_path,
                                                     load_in_low_bit="asym_int4")
        tempdir = tempfile.mkdtemp(dir=output_dir)
        model.save_low_bit(tempdir)
        newModel = AutoModelForCausalLM.load_low_bit(tempdir)
        import shutil
        shutil.rmtree(tempdir)


if __name__ == '__main__':
    pytest.main([__file__])