LLM: Implement hf low_cpu_mem_usage with 1xbinary file peak memory on transformer int4 (#8731)

* 1x peak memory
This commit is contained in:
Zhao Changmin 2023-08-29 09:33:17 +08:00 committed by GitHub
parent 5d90ca2dac
commit bb31d4fe80
3 changed files with 203 additions and 137 deletions

View file

@ -44,27 +44,10 @@ import importlib
def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
current_key_name=None, convert_shape_only=False):
from bigdl.llm.transformers.linear_quant import LinearQuant, ParamsQuant
current_key_name=None):
from bigdl.llm.transformers.linear_quant import LinearQuant, FP4Params
has_been_replaced = False
# Through our method, certain layers that were initialized on the device "meta"
# (associated with the lazy initialization strategy of low_cpu_mem_usage) are not
# being correctly moved back to the CPU device for some reason. Therefore, we are
# moving these layers back to the CPU here in order to prevent the occurrence
# of NoImplementnError. Details refer to:
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3110
model_state_dict = model.state_dict()
for name, param in model.named_parameters():
if param.data.device == torch.device('meta'):
from accelerate.utils.modeling import set_module_tensor_to_device
param = model_state_dict[name]
set_module_tensor_to_device(model,
name,
"cpu",
torch.empty(*param.size(), dtype=torch.float32))
del model_state_dict
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
@ -80,17 +63,18 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
module.bias is not None,
)
device_type = module.weight.data.device.type
# Copy the weights
paramsQuant = ParamsQuant(data=module.weight.data,
requires_grad=False,
quantized=False,
convert_shape_only=convert_shape_only,
_shape=None,
qtype=qtype).to("cpu")
paramsQuant = FP4Params(data=module.weight.data,
requires_grad=False,
quantized=False,
_shape=None,
qtype=qtype).to(device_type)
new_linear._parameters['weight'] = paramsQuant
if module.bias is not None:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data).to("cpu")
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
.to(device_type)
model._modules[name] = new_linear
has_been_replaced = True
@ -106,15 +90,14 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
qtype,
modules_to_not_convert,
current_key_name,
convert_shape_only,
)
return model, has_been_replaced
def ggml_convert_quant(model, qtype, optimize_model=True, convert_shape_only=False):
def ggml_convert_quant(model, qtype, optimize_model=True, device="cpu"):
modules_to_not_convert = [] # ["lm_head"]
model, has_been_replaced = _replace_with_quant_linear(
model, qtype, modules_to_not_convert, None, convert_shape_only=convert_shape_only
model, qtype, modules_to_not_convert, None
)
if not has_been_replaced:
warnings.warn(
@ -123,8 +106,11 @@ def ggml_convert_quant(model, qtype, optimize_model=True, convert_shape_only=Fal
"instead of Linear layers. Please double check your model architecture, or submit "
"an issue on github if you think this is a bug."
)
else:
elif device == "cpu":
model.to(torch.float32)
elif device == "meta":
# Do nothing here for weights are empty.
pass
if optimize_model:
model = optimize(model)

View file

@ -59,7 +59,7 @@ TORCH_LINEAR_THRESHOLD = 96
SYM_INT4 = ggml_tensor_qtype["sym_int4"]
def ggml_convert_quant(tensor: torch.Tensor, qtype: int, convert_shape_only=False):
def ggml_convert_quant(tensor: torch.Tensor, qtype: int, device=None):
QK = ggml.ggml_qk_size(qtype)
block_size_in_bytes = ggml.ggml_type_size(qtype)
@ -75,12 +75,12 @@ def ggml_convert_quant(tensor: torch.Tensor, qtype: int, convert_shape_only=Fals
"Last dim of input tensor must be multiple of 64")
dst_size = (n // QK) * block_size_in_bytes
dst_tensor = torch.empty(dst_size, dtype=torch.uint8)
dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
device=device)
hist = (ctypes.c_int64 * 16)()
if not convert_shape_only:
if device != 'meta':
dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
hist = (ctypes.c_int64 * 16)()
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
return dst_tensor
@ -98,14 +98,16 @@ def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
return dst_tensor
class ParamsQuant(torch.nn.Parameter):
# Rename to FP4Params to trigger initializing
# the params layer with all parameters on the CPU
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
class FP4Params(torch.nn.Parameter):
def __new__(cls,
data=None,
requires_grad=True,
requires_grad=False,
old_data=None,
quantized=False,
_shape=None,
convert_shape_only=False,
qtype=None):
if data is None:
data = torch.empty(0)
@ -114,16 +116,14 @@ class ParamsQuant(torch.nn.Parameter):
self.data = data
self.quantized = quantized
self._shape = _shape
self.convert_shape_only = convert_shape_only
self.qtype = qtype
return self
def quantize(self, device):
def quantize(self, device=None):
if not self.quantized:
w = self.data.contiguous().float()
# self.old_data = self.data
w_quantized = ggml_convert_quant(w, self.qtype,
convert_shape_only=self.convert_shape_only)
device=device)
self.data = w_quantized
self.quantized = True
self._shape = w.shape
@ -147,28 +147,29 @@ class ParamsQuant(torch.nn.Parameter):
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if (device is not None and device.type == "cpu" and self.data.device.type == "cpu"):
return self.quantize(device)
return self.quantize(device.type)
elif device is not None and device.type == "meta" and self.data.device.type == "meta":
return self.quantize(device.type)
elif (device is not None and device.type == "xpu" and self.data.device.type == "cpu"):
# enter xpu logic, compile linear_int4 extension at first time
q_tensor = self.quantize(device) # tensor is cpu now
new_param = ParamsQuant(super().to(device=device,
dtype=dtype,
non_blocking=non_blocking),
requires_grad=self.requires_grad,
quantized=self.quantized,
_shape=self._shape,
qtype=self.qtype)
new_param = FP4Params(super().to(device=device,
dtype=dtype,
non_blocking=non_blocking),
requires_grad=self.requires_grad,
quantized=self.quantized,
_shape=self._shape,
qtype=self.qtype)
return new_param
else:
new_param = ParamsQuant(super().to(device=device,
dtype=dtype,
non_blocking=non_blocking),
requires_grad=self.requires_grad,
quantized=self.quantized,
_shape=self._shape,
qtype=self.qtype)
new_param = FP4Params(super().to(device=device,
dtype=dtype,
non_blocking=non_blocking),
requires_grad=self.requires_grad,
quantized=self.quantized,
_shape=self._shape,
qtype=self.qtype)
return new_param
@ -213,9 +214,10 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
class LinearQuant(nn.Linear):
def __init__(self, input_features, output_features, qtype, bias=True):
super().__init__(input_features, output_features, bias)
self.weight = ParamsQuant(self.weight.data, requires_grad=False,
old_data=self.weight.data,
quantized=False, _shape=None, qtype=qtype)
self.weight = FP4Params(self.weight.data,
requires_grad=False,
old_data=self.weight.data,
quantized=False, _shape=None, qtype=qtype)
self.in_len = input_features
self.out_len = output_features
self.weight_shape = (self.out_len, self.in_len)
@ -223,7 +225,6 @@ class LinearQuant(nn.Linear):
self.qtype = qtype
def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)

View file

@ -14,18 +14,14 @@
# limitations under the License.
#
import gc
import transformers
from transformers.configuration_utils import PretrainedConfig
from .utils import extract_local_archive_file, \
load_state_dict, \
load, \
get_local_shard_files, \
fix_key
get_local_shard_files
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError, MuteHFLogger
import sys
import importlib
from bigdl.llm.utils.common import invalidInputError
import torch
def save_low_bit(self, *args, **kwargs):
@ -33,6 +29,15 @@ def save_low_bit(self, *args, **kwargs):
f"Detected this model is not a low-bit model, please use from_pretrained's"
f" load_in_4bit or load_in_low_bit parameter to load a 4-bit model first.")
self.save_pretrained(*args, **kwargs)
import json
import os
# We conveniently save all the keys of the model to have them on hand,
# so that when using 'low_cpumem load',
# it's not necessary to load the entire model to extract its keys
# and we can avoid gc not triggered potentially.
load_keys = {"all_checkpoint_keys": list(self.state_dict().keys())}
with open(os.path.join(args[0], "load_keys.json"), "w") as json_file:
json.dump(load_keys, json_file)
class _BaseAutoModelClass:
@ -106,11 +111,44 @@ class _BaseAutoModelClass:
@classmethod
def load_low_bit(cls,
*args,
pretrained_model_name_or_path,
*model_args,
**kwargs):
# Read bigdl_transformers_low_bit from config.json
pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
if len(args) == 0 else args[0]
from transformers.modeling_utils import no_init_weights, get_state_dict_dtype
from transformers.dynamic_module_utils import resolve_trust_remote_code, \
get_class_from_dynamic_module
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.utils.generic import ContextManagers
from transformers.generation.configuration_utils import GenerationConfig
from transformers.models.auto.auto_factory import _get_model_class
from accelerate.big_modeling import init_empty_weights
from .convert import ggml_convert_quant
import copy
import os
# Autofactory
trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs_orig = copy.deepcopy(kwargs)
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
return_unused_kwargs=True,
trust_remote_code=trust_remote_code,
**kwargs,
)
# if torch_dtype=auto was passed here, ensure to pass it on
if kwargs_orig.get("torch_dtype", None) == "auto":
kwargs["torch_dtype"] = "auto"
# Maybe needed when extract_local_archive_file
subfolder = kwargs.get("subfolder", "")
variant = kwargs.get("variant", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
torch_dtype = kwargs.pop("torch_dtype", "auto")
sharded_metadata = None
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
@ -123,89 +161,130 @@ class _BaseAutoModelClass:
f"Unknown bigdl_transformers_low_bit value: {bigdl_transformers_low_bit},"
f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
# Speed up when loading model
kwargs["low_cpu_mem_usage"] = True
# set default torch_dtype='auto'
kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
# set default optimize_model=True
optimize_model = kwargs.pop("optimize_model", True)
qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
# Note that the int4 linear layers cannot currently
# be recorded in huggingface Pretrained Model or AutoConfig,
# and huggingface transformers cls.HF_Model.from_pretrained
# could only restore the model in the original format,
# which is not quantized. we can Initialize original model first,
# convert the model to quantized int4 format later, and then load the quantized model.
# Avoid KeyError
kwargs["ignore_mismatched_sizes"] = True
has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
has_local_code = type(config) in cls.HF_Model._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)
if has_remote_code and trust_remote_code:
class_ref = config.auto_map[cls.HF_Model.__name__]
model_class = get_class_from_dynamic_module(
class_ref, pretrained_model_name_or_path, **kwargs
)
if os.path.isdir(pretrained_model_name_or_path):
model_class.register_for_auto_class(cls.HF_Model.__name__)
else:
cls.HF_Model.register(config.__class__, model_class, exist_ok=True)
elif type(config) in cls.HF_Model._model_mapping.keys():
model_class = _get_model_class(config, cls.HF_Model._model_mapping)
# Maybe needed when extract_local_archive_file
subfolder = kwargs.get("subfolder", "")
variant = kwargs.get("variant", None)
from .convert import ggml_convert_quant
with MuteHFLogger(logger=transformers.modeling_utils.logger):
model = cls.HF_Model.from_pretrained(*args, **kwargs)
# add save_low_bit to pretrained model dynamically
import types
model.save_low_bit = types.MethodType(save_low_bit, model)
# We forcefully modify the model's definition
# and the tensor shape of int4 weights without quantization.
model = ggml_convert_quant(model, qtype, optimize_model, convert_shape_only=True)
# Load the quantized model at last.
resolved_archive_file, is_sharded = extract_local_archive_file(
pretrained_model_name_or_path,
subfolder,
variant)
if is_sharded:
resolved_archive_file, sharded_metadata = \
get_local_shard_files(pretrained_model_name_or_path,
resolved_archive_file,
subfolder=subfolder)
start_prefix = ""
prefix = model.base_model_prefix
loaded_keys = [fix_key(key) for key in sharded_metadata["all_checkpoint_keys"]]
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
else:
has_prefix_module = False
model_cls = type(model)
if len(model_cls.base_model_prefix) > 0 and \
not hasattr(model, model_cls.base_model_prefix) and \
has_prefix_module:
start_prefix = model_cls.base_model_prefix + "."
from transformers.modeling_utils import _load_state_dict_into_model
error_msgs = []
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
error_msgs += _load_state_dict_into_model(model, state_dict, start_prefix)
# force memory release
del state_dict
gc.collect()
# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict,
# by checking its first weights entry that is of a floating type
# - we assume all floating dtype weights are of the same dtype
# we also may have config.torch_dtype available, but we won't rely on it till v5
dtype_orig = None
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg:
error_msg += (
"\n\tYou may consider adding `ignore_mismatched_sizes=True`"
" in the model `from_pretrained` method."
)
invalidInputError(False, "Error(s) in loading state_dict"
f"for {model.__class__.__name__}:\n\t{error_msg}")
if torch_dtype is not None:
if isinstance(torch_dtype, str):
if torch_dtype == "auto":
if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
torch_dtype = config.torch_dtype
else:
if is_sharded and "dtype" in sharded_metadata:
torch_dtype = sharded_metadata["dtype"]
else:
one_state_dict = load_state_dict(resolved_archive_file[0])
torch_dtype = get_state_dict_dtype(one_state_dict)
del one_state_dict # free CPU memory
else:
invalidInputError(False,
f'`torch_dtype` can be either `torch.dtype` or `"auto"`,'
'but received {torch_dtype}')
dtype_orig = model_class._set_default_torch_dtype(torch_dtype)
# Pretrained Model
_fast_init = kwargs.pop("_fast_init", True)
init_contexts = [no_init_weights(_enable=_fast_init)]
init_contexts.append(init_empty_weights())
with ContextManagers(init_contexts):
model = model_class(config, *model_args, **kwargs)
model = ggml_convert_quant(model, qtype, optimize_model, device="meta")
if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
state_dict = load_state_dict(resolved_archive_file)
load(model, state_dict)
del state_dict
import os
import json
with open(os.path.join(pretrained_model_name_or_path,
"load_keys.json"), "r") as json_file:
loaded_data = json.load(json_file)
loaded_state_dict_keys = loaded_data["all_checkpoint_keys"]
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = model_class._load_pretrained_model(
model,
None,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=True,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
keep_in_fp32_modules=[],
)
# make sure token embedding weights are still tied if needed
model.tie_weights()
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
# If it is a model with generation capabilities, attempt to load the generation config
if model.can_generate():
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder=subfolder,
**kwargs,
)
except (OSError, TypeError):
pass
for param in model.parameters():
param.requires_grad_(False)
return model