Add load_low_bit and save_low_bit to AutoModelForCausalLM (#8531)
* transformers save_low_bit load_low_bit
* update example and add readme
* update
* update
* update
* add ut
* update
This commit is contained in:
parent 808a64d53a
commit fccae91461

4 changed files with 202 additions and 58 deletions

@@ -0,0 +1,43 @@
# BigDL-LLM Transformers INT4 Inference Pipeline for Large Language Model
In this example, we show a pipeline to apply BigDL-LLM low-bit optimizations to any Hugging Face Transformers model, and then run inference on the optimized low-bit model.
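Under the hood, the example relies on three calls: `from_pretrained` with the `load_in_low_bit` argument to quantize the model while loading it, `save_low_bit` to persist the converted weights, and `load_low_bit` to restore them later. A minimal sketch of this API (the model id and save path below are placeholders):

```python
from bigdl.llm.transformers import AutoModelForCausalLM

# Quantize while loading; options: sym_int4, asym_int4, sym_int5, asym_int5, sym_int8
model = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf",
                                             load_in_low_bit="sym_int4")

# Persist the already-converted low-bit weights
model.save_low_bit("./llama-7b-sym_int4")

# Later, restore the low-bit model directly without converting again
model = AutoModelForCausalLM.load_low_bit("./llama-7b-sym_int4")
```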
## Prepare Environment
We suggest using conda to manage the environment:

```bash
conda create -n llm python=3.9
conda activate llm

pip install --pre --upgrade bigdl-llm[all]
```
## Run Example
```bash
python ./transformers_low_bit_pipeline.py --model-path decapoda-research/llama-7b-hf --low-bit sym_int5 --save-path ./llama-7b-sym_int5
```
Arguments info:

- `--model-path`: str value, the Hugging Face repo id of the large language model to be downloaded, or the path to the Hugging Face checkpoint folder; defaults to `decapoda-research/llama-7b-hf`.
- `--low-bit`: str value, options are `sym_int4`, `asym_int4`, `sym_int5`, `asym_int5` or `sym_int8` (`sym_int4` means symmetric int 4, `asym_int4` means asymmetric int 4, etc.). The corresponding low-bit optimizations will be applied to the model.
- `--save-path`: str value, the path to save the low-bit model; you can then load the low-bit model directly from this path.
- `--load-path`: optional str value, the path to load a previously saved low-bit model, as shown in the example below.
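For instance, a typical save-then-load workflow with these arguments might look as follows (the model id and paths are only illustrative):

```bash
# First run: quantize to asymmetric int4 and save the converted model
python ./transformers_low_bit_pipeline.py --model-path decapoda-research/llama-7b-hf --low-bit asym_int4 --save-path ./llama-7b-asym_int4

# Later runs: skip the conversion and load the saved low-bit model directly
python ./transformers_low_bit_pipeline.py --load-path ./llama-7b-asym_int4
```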
## Sample Output for Inference
### 'decapoda-research/llama-7b-hf' Model
```log
Prompt: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
Output: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. She wanted to be a princess, and she wanted to be a pirate. She wanted to be a superhero, and she wanted to be
Model and tokenizer are saved to ./llama-7b-sym_int5
```
### Load low-bit model
Command to run:
```bash
python ./transformers_low_bit_pipeline.py --load-path ./llama-7b-sym_int5
```
Output log:
```log
Prompt: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
Output: Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. She wanted to be a princess, and she wanted to be a pirate. She wanted to be a superhero, and she wanted to be
```

@@ -0,0 +1,56 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer, TextGenerationPipeline

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Transformer save_load example')
    parser.add_argument('--model-path', type=str, default="decapoda-research/llama-7b-hf",
                        help='The huggingface repo id for the large language model to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--low-bit', type=str, default="sym_int4",
                        choices=['sym_int4', 'asym_int4', 'sym_int5', 'asym_int5', 'sym_int8'],
                        help='The quantization type the model will convert to.')
    parser.add_argument('--save-path', type=str, default=None,
                        help='The path to save the low-bit model.')
    parser.add_argument('--load-path', type=str, default=None,
                        help='The path to load the low-bit model.')
    args = parser.parse_args()
    model_path = args.model_path
    low_bit = args.low_bit
    load_path = args.load_path
    if load_path:
        model = AutoModelForCausalLM.load_low_bit(load_path)
        tokenizer = LlamaTokenizer.from_pretrained(load_path)
    else:
        # load_in_low_bit in bigdl.llm.transformers will convert
        # the relevant layers in the model into corresponding int X format
        model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit)
        tokenizer = LlamaTokenizer.from_pretrained(model_path)

    pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer, max_new_tokens=32)
    input_str = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
    output = pipeline(input_str)[0]["generated_text"]
    print(f"Prompt: {input_str}")
    print(f"Output: {output}")

    save_path = args.save_path
    if save_path:
        model.save_low_bit(save_path)
        tokenizer.save_pretrained(save_path)
        print(f"Model and tokenizer are saved to {save_path}")
@@ -21,6 +21,13 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 
 
+def save_low_bit(self, *args, **kwargs):
+    invalidInputError(self.config.to_dict().get("bigdl_transformers_low_bit", False),
+                      f"Detected this model is not a low-bit model, please use from_pretrained's"
+                      f" load_in_4bit or load_in_low_bit parameter to load a 4-bit model first.")
+    self.save_pretrained(*args, **kwargs)
+
+
 class _BaseAutoModelClass:
 
     HF_MODEL = None
@@ -36,34 +43,76 @@ class _BaseAutoModelClass:
         Two new arguments are added to extend Hugging Face's from_pretrained method as follows:
         New Arguments:
         load_in_4bit: boolean value, True means load linear's weight to symmetric int 4.
-        load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5 or
-                         sym_int8. The model's linear will be loaded into corresponding
-                         low-bit type. sym_int4 means symmetric int 4, asym_int4 means
-                         asymmetric int 4.
+        load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5
+                         or sym_int8. (sym_int4 means symmetric int 4, asym_int4 means
+                         asymmetric int 4, etc.). Relevant low bit optimizations will
+                         be applied to the model.
         """
+        pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
+            if len(args) == 0 else args[0]
+        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
+        bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
+        invalidInputError(not bigdl_transformers_low_bit,
+                          f"Detected model is a low-bit({bigdl_transformers_low_bit}) model, "
+                          f"Please use load_low_bit to load this model.")
+
         # For huggingface transformers cls.HF_Model.from_pretrained could only restore the model
         # in the original format, which is not quantized,
         # we can convert the model to quantized later.
-        model = None
         load_in_4bit = kwargs.pop("load_in_4bit", False)
         load_in_low_bit = kwargs.pop("load_in_low_bit", None)
+
+        if load_in_4bit or load_in_low_bit:
+            # load int x-bit
+            kwargs["low_cpu_mem_usage"] = True
+            q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
+            model = cls.load_convert(q_k, *args, **kwargs)
+        else:
+            # load default
+            model = cls.HF_Model.from_pretrained(*args, **kwargs)
+
+        return model
+
+    @classmethod
+    def load_convert(cls, q_k, *args, **kwargs):
+        from .convert import ggml_convert_quant
+        invalidInputError(q_k in ggml_tensor_qtype,
+                          f"Unknown load_in_low_bit value: {q_k}, expected:"
+                          f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
+        qtype = ggml_tensor_qtype[q_k]
+        model = cls.HF_Model.from_pretrained(*args, **kwargs)
+        model = model.to("cpu")
+        model = ggml_convert_quant(model, qtype)
+        model.config.update({"bigdl_transformers_low_bit": q_k})
+
+        # add save_low_bit to pretrained model dynamically
+        import types
+        model.save_low_bit = types.MethodType(save_low_bit, model)
+
+        return model
+
+    @classmethod
+    def load_low_bit(cls,
+                     *args,
+                     **kwargs):
         # Read bigdl_transformers_low_bit from config.json
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
             if len(args) == 0 else args[0]
         config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
         bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
-        if load_in_4bit or load_in_low_bit or bigdl_transformers_low_bit:
+        invalidInputError(bigdl_transformers_low_bit,
+                          "Detect this model is not a low-bit model, Please use from_pretrained"
+                          " with load_in_4bit or load_in_low_bit to get a low-bit model , and "
+                          " serialize the model using save_low_bit first.")
+
+        invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype,
+                          f"Unknown bigdl_transformers_low_bit value: {bigdl_transformers_low_bit},"
+                          f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
+
         # Speed up when loading model
         kwargs["low_cpu_mem_usage"] = True
-        if bigdl_transformers_low_bit:
-            invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype,
-                              f"Unknown bigdl_transformers_low_bit value:"
-                              f" {bigdl_transformers_low_bit},"
-                              f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
-            qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
+
+        qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
         # Note that the int4 linear layers cannot currently
         # be recorded in huggingface Pretrained Model or AutoConfig,
@@ -87,6 +136,10 @@ class _BaseAutoModelClass:
                            "they can be safely ignored; "
                            "the model will be loaded with INT4 optimizations applied.")
 
+        # add save_low_bit to pretrained model dynamically
+        import types
+        model.save_low_bit = types.MethodType(save_low_bit, model)
+
         # We forcefully modify the model's definition
         # and the tensor shape of int4 weights without quantization.
         model = ggml_convert_quant(model, qtype, convert_shape_only=True)
@@ -98,23 +151,6 @@ class _BaseAutoModelClass:
         load(model, state_dict)
         del state_dict
 
-        elif load_in_4bit or load_in_low_bit:
-            q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
-            model = cls.convert_quant(model, q_k, *args, **kwargs)
-
-        return model
-
-    @classmethod
-    def convert_quant(cls, model, q_k, *args, **kwargs):
-        from .convert import ggml_convert_quant
-        invalidInputError(q_k in ggml_tensor_qtype,
-                          f"Unknown load_in_low_bit value: {q_k}, expected:"
-                          f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
-        qtype = ggml_tensor_qtype[q_k]
-        model = cls.HF_Model.from_pretrained(*args, **kwargs)
-        model = model.to("cpu")
-        model = ggml_convert_quant(model, qtype)
-        model.config.update({"bigdl_transformers_low_bit": q_k})
         return model
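Both `load_convert` and `load_low_bit` attach `save_low_bit` to the returned model object at runtime with `types.MethodType`, so the new method is available on any low-bit model instance without subclassing the Hugging Face model classes. A standalone sketch of that binding pattern (the `FakeModel` class and the simplified body of `save_low_bit` below are illustrative stand-ins, not the real implementation):

```python
import types


def save_low_bit(self, path):
    # The real method checks config["bigdl_transformers_low_bit"] and then
    # delegates to self.save_pretrained(path); here we only print.
    print(f"saving low-bit weights of {type(self).__name__} to {path}")


class FakeModel:
    pass


model = FakeModel()
# Bind the plain function as a method of this one instance, as the patch does.
model.save_low_bit = types.MethodType(save_low_bit, model)
model.save_low_bit("./llama-7b-sym_int4")
```

The chosen quantization type itself is persisted in the checkpoint's `config.json` under the `bigdl_transformers_low_bit` key; `load_low_bit` reads it back via `PretrainedConfig.get_config_dict` and maps it to the matching `ggml_tensor_qtype` before restoring the low-bit weights.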
@@ -79,6 +79,15 @@ class TestConvertModel(TestCase):
         model = AutoModelForCausalLM.from_pretrained(llama_model_path,
                                                      load_in_low_bit="sym_int8")
 
+    def test_transformer_convert_llama_save_load(self):
+        model = AutoModelForCausalLM.from_pretrained(llama_model_path,
+                                                     load_in_low_bit="asym_int4")
+        tempdir = tempfile.mkdtemp(dir=output_dir)
+        model.save_low_bit(tempdir)
+        newModel = AutoModelForCausalLM.load_low_bit(tempdir)
+        import shutil
+        shutil.rmtree(tempdir)
+
 
 if __name__ == '__main__':
     pytest.main([__file__])
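The new unit test only exercises the save/load round trip; if one also wanted to check that the reloaded model behaves like the original, a sketch along the following lines could work (the prompt, the logits comparison and the tolerance are illustrative assumptions, not part of the committed test):

```python
import shutil
import tempfile

import torch
from transformers import LlamaTokenizer

from bigdl.llm.transformers import AutoModelForCausalLM


def check_low_bit_round_trip(llama_model_path, output_dir):
    # Quantize, save, reload, then compare one forward pass of both models.
    model = AutoModelForCausalLM.from_pretrained(llama_model_path, load_in_low_bit="asym_int4")
    tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
    tempdir = tempfile.mkdtemp(dir=output_dir)
    try:
        model.save_low_bit(tempdir)
        reloaded = AutoModelForCausalLM.load_low_bit(tempdir)
        inputs = tokenizer("Once upon a time", return_tensors="pt")
        with torch.no_grad():
            original_logits = model(**inputs).logits
            reloaded_logits = reloaded(**inputs).logits
        assert torch.allclose(original_logits, reloaded_logits, atol=1e-4)
    finally:
        shutil.rmtree(tempdir)
```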