LLM: Optimize transformer int4 loading (#8499)
* LLM: Optimize transformer int4 loading
This commit is contained in:
parent dd3f953288
commit 23f6a4c21f
3 changed files with 69 additions and 37 deletions
@@ -17,6 +17,7 @@
 # Some parts of this file is adapted from
 # https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/utils/bitsandbytes.py
+# and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py
 # which is licensed under Apache License 2.0:
 #
 # Copyright 2021 The HuggingFace Inc. team. All rights reserved.

@@ -44,6 +45,23 @@ import warnings
 def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
                                current_key_name=None, convert_shape_only=False):
     has_been_replaced = False
 
+    # With our method, certain layers that were initialized on the device "meta"
+    # (associated with the lazy initialization strategy of low_cpu_mem_usage) are not
+    # being correctly moved back to the CPU device for some reason. Therefore, we
+    # move these layers back to the CPU here in order to prevent the occurrence
+    # of NotImplementedError. For details, refer to:
+    # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3110
+    model_state_dict = model.state_dict()
+    for name, param in model.named_parameters():
+        if param.data.device == torch.device('meta'):
+            from accelerate.utils.modeling import set_module_tensor_to_device
+            param = model_state_dict[name]
+            set_module_tensor_to_device(model,
+                                        name,
+                                        "cpu",
+                                        torch.empty(*param.size(), dtype=torch.float32))
+
     for name, module in model.named_children():
         if current_key_name is None:
             current_key_name = []
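
For readers unfamiliar with the accelerate meta-device mechanics that the added block relies on, here is a minimal, self-contained sketch (illustrative only, using a plain torch.nn.Linear rather than the model handled above) of how a parameter that still lives on the "meta" device is given real CPU storage with set_module_tensor_to_device:

    # Illustrative sketch (not part of the commit): materialize "meta" parameters on CPU,
    # mirroring the workaround added in _replace_with_quant_linear above.
    import torch
    from accelerate import init_empty_weights
    from accelerate.utils.modeling import set_module_tensor_to_device

    with init_empty_weights():
        layer = torch.nn.Linear(8, 8)   # weight and bias live on "meta" and hold no storage

    for name, param in layer.named_parameters():
        if param.device == torch.device("meta"):
            # Allocate real (uninitialized) CPU storage for the parameter; the actual
            # values are expected to be filled in by a later state_dict load.
            set_module_tensor_to_device(layer, name, "cpu",
                                        torch.empty(*param.size(), dtype=torch.float32))

    print(layer.weight.device)  # cpu
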
@@ -86,6 +104,7 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
                 qtype,
                 modules_to_not_convert,
                 current_key_name,
+                convert_shape_only,
             )
     return model, has_been_replaced

@@ -29,50 +29,54 @@ class _BaseAutoModelClass:
     def from_pretrained(cls,
                         *args,
                         **kwargs):
-        load_in_4bit = kwargs.pop("load_in_4bit", False)
-        qtype = 0
-        if load_in_4bit:
-            kwargs["low_cpu_mem_usage"] = True
-            qtype = ggml_tensor_qtype['q4_0']
-        load_in_low_bit = kwargs.pop("load_in_low_bit", "").lower()
-        if load_in_low_bit:
-            kwargs["low_cpu_mem_usage"] = True
-            invalidInputError(qtype in ggml_tensor_qtype,
-                              f"Unknown load_in_low_bit value: {qtype},"
-                              f" excepted q4_0, q4_1, q5_0, q5_1, q8_0.")
-            qtype = ggml_tensor_qtype[load_in_low_bit]
-
-        subfolder = kwargs.get("subfolder", "")
-        variant = kwargs.get("variant", None)
-        pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
-            if len(args) == 0 else args[0]
-
         # For huggingface transformers cls.HF_Model.from_pretrained could only restore the model
         # in the original format, which is not quantized,
         # we can convert the model to quantized later.
         model = None
-        # Read bigdl_transformers_int4 from config.json
+        load_in_4bit = kwargs.pop("load_in_4bit", False)
+        load_in_low_bit = kwargs.pop("load_in_low_bit", None)
+
+        # Read bigdl_transformers_low_bit from config.json
+        pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
+            if len(args) == 0 else args[0]
         config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
-        bigdl_transformers_int4 = config_dict.pop("bigdl_transformers_int4", False)
-        if bigdl_transformers_int4:
+        bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
+
+        if load_in_4bit or load_in_low_bit or bigdl_transformers_low_bit:
+            # Speed up when loading model
+            kwargs["low_cpu_mem_usage"] = True
+
+        if bigdl_transformers_low_bit:
+            invalidInputError(bigdl_transformers_low_bit in ggml_tensor_qtype,
+                              f"Unknown load_in_low_bit value: {bigdl_transformers_low_bit},"
+                              f" expected q4_0, q4_1, q5_0, q5_1, q8_0.")
+            qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
+            # Note that the int4 linear layers cannot currently
+            # be recorded in huggingface Pretrained Model or AutoConfig,
+            # and huggingface transformers cls.HF_Model.from_pretrained
+            # could only restore the model in the original format,
+            # which is not quantized. We can initialize the original model first,
+            # convert the model to the quantized int4 format later, and then load the quantized model.
             # Avoid KeyError
             kwargs["ignore_mismatched_sizes"] = True
+            # Avoid reading from local file at the first initialization
+            kwargs["state_dict"] = {}
+
-        model = cls.HF_Model.from_pretrained(*args, **kwargs)
-        print("Note: If there are warnings about mismatched during the loading process, "
-              "please ignore them as it is part of the normal flow. "
-              "The model will be reconverted to the format of BigDL after loading.")
-
-        # Note that the ggml_matmul_src1_x_src0_t operation cannot currently
-        # be recorded in AutoConfig,
-        # and this operation is not included in the core Hugging Face infrastructure.
-        if bigdl_transformers_int4:
+            # May be needed by extract_local_archive_file
+            subfolder = kwargs.get("subfolder", "")
+            variant = kwargs.get("variant", None)
+
             from .convert import ggml_convert_quant
+            model = cls.HF_Model.from_pretrained(*args, **kwargs)
+            print("Note: If there are warnings during the model loading process, "
+                  "they can be safely ignored; "
+                  "the model will be loaded with INT4 optimizations applied.")
+
             # We forcefully modify the model's definition
             # and the tensor shape of int4 weights without quantization.
-            model = ggml_convert_quant(model, convert_shape_only=True)
+            model = ggml_convert_quant(model, qtype, convert_shape_only=True)
             # Load the quantized model at last.
             archive_file = extract_local_archive_file(pretrained_model_name_or_path,
                                                       subfolder,
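
In practice, the reworked from_pretrained means the quantization type is chosen at load time: load_in_4bit=True maps to q4_0, while load_in_low_bit accepts q4_0, q4_1, q5_0, q5_1 or q8_0. A usage sketch follows, assuming the package exposes an AutoModelForCausalLM wrapper built on _BaseAutoModelClass and that a local checkpoint path is available (both are assumptions, not shown in this diff):

    # Usage sketch (assumptions: bigdl.llm.transformers provides AutoModelForCausalLM
    # built on _BaseAutoModelClass, and "/path/to/llama-7b-hf" is a local HF checkpoint).
    from bigdl.llm.transformers import AutoModelForCausalLM

    # Equivalent to load_in_low_bit="q4_0"; low_cpu_mem_usage is forced on internally
    # and the freshly loaded model is converted with ggml_convert_quant.
    model = AutoModelForCausalLM.from_pretrained("/path/to/llama-7b-hf",
                                                 load_in_4bit=True)

    # Or pick another ggml quantization type explicitly.
    model_q5 = AutoModelForCausalLM.from_pretrained("/path/to/llama-7b-hf",
                                                    load_in_low_bit="q5_0")
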
@@ -80,12 +84,24 @@ class _BaseAutoModelClass:
             state_dict = load_state_dict(archive_file)
             load(model, state_dict)
             del state_dict
-        elif qtype:
+        elif load_in_4bit or load_in_low_bit:
+            q_k = load_in_low_bit if load_in_low_bit else "q4_0"
+            model = cls.convert_quant(model, q_k, *args, **kwargs)
+
+        return model
+
+    @classmethod
+    def convert_quant(cls, model, q_k, *args, **kwargs):
         from .convert import ggml_convert_quant
+        invalidInputError(q_k in ggml_tensor_qtype,
+                          f"Unknown load_in_low_bit value: {q_k},"
+                          f" expected q4_0, q4_1, q5_0, q5_1, q8_0.")
+        qtype = ggml_tensor_qtype[q_k]
+        model = cls.HF_Model.from_pretrained(*args, **kwargs)
         model = model.to("cpu")
         model = ggml_convert_quant(model, qtype)
-        model.config.update({"bigdl_transformers_int4": True})
+        model.config.update({"bigdl_transformers_low_bit": q_k})
 
         return model
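
Since convert_quant now records the chosen type in model.config as bigdl_transformers_low_bit, saving the converted model with the standard Hugging Face save_pretrained should persist that flag to config.json, which is what the bigdl_transformers_low_bit branch of from_pretrained above looks for. A hedged sketch of that save-and-reload round trip, under the same assumptions as the previous example and with placeholder paths:

    # Round-trip sketch (same assumptions as above; paths are placeholders).
    from bigdl.llm.transformers import AutoModelForCausalLM

    # First load: quantize from the original checkpoint and persist the result.
    model = AutoModelForCausalLM.from_pretrained("/path/to/llama-7b-hf",
                                                 load_in_low_bit="q4_0")
    model.save_pretrained("/path/to/llama-7b-q4_0")   # config.json now carries
                                                       # "bigdl_transformers_low_bit": "q4_0"

    # Later loads: from_pretrained sees the flag in config.json, rebuilds the int4
    # layer shapes with convert_shape_only=True, then loads the quantized state dict.
    model = AutoModelForCausalLM.from_pretrained("/path/to/llama-7b-q4_0")
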
@@ -52,9 +52,6 @@ WEIGHTS_NAME = "pytorch_model.bin"
 
 def extract_local_archive_file(pretrained_model_name_or_path, subfolder, variant):
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-    print(os.path.join(pretrained_model_name_or_path,
-                       subfolder,
-                       _add_variant(WEIGHTS_NAME, variant)))
     if os.path.isfile(
         os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
     ):