diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 5799b3f4..72624549 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -44,27 +44,10 @@ import importlib


 def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
-                               current_key_name=None, convert_shape_only=False):
-    from bigdl.llm.transformers.linear_quant import LinearQuant, ParamsQuant
+                               current_key_name=None):
+    from bigdl.llm.transformers.linear_quant import LinearQuant, FP4Params
     has_been_replaced = False

-    # Through our method, certain layers that were initialized on the device "meta"
-    # (associated with the lazy initialization strategy of low_cpu_mem_usage) are not
-    # being correctly moved back to the CPU device for some reason. Therefore, we are
-    # moving these layers back to the CPU here in order to prevent the occurrence
-    # of NoImplementnError. Details refer to:
-    # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3110
-    model_state_dict = model.state_dict()
-    for name, param in model.named_parameters():
-        if param.data.device == torch.device('meta'):
-            from accelerate.utils.modeling import set_module_tensor_to_device
-            param = model_state_dict[name]
-            set_module_tensor_to_device(model,
-                                        name,
-                                        "cpu",
-                                        torch.empty(*param.size(), dtype=torch.float32))
-    del model_state_dict
-
     for name, module in model.named_children():
         if current_key_name is None:
             current_key_name = []
@@ -80,17 +63,18 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
                     module.bias is not None,
                 )

+                device_type = module.weight.data.device.type
                 # Copy the weights
-                paramsQuant = ParamsQuant(data=module.weight.data,
-                                          requires_grad=False,
-                                          quantized=False,
-                                          convert_shape_only=convert_shape_only,
-                                          _shape=None,
-                                          qtype=qtype).to("cpu")
+                paramsQuant = FP4Params(data=module.weight.data,
+                                        requires_grad=False,
+                                        quantized=False,
+                                        _shape=None,
+                                        qtype=qtype).to(device_type)
                 new_linear._parameters['weight'] = paramsQuant

                 if module.bias is not None:
-                    new_linear._parameters['bias'] = nn.Parameter(module.bias.data).to("cpu")
+                    new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                        .to(device_type)

                 model._modules[name] = new_linear
                 has_been_replaced = True
@@ -106,15 +90,14 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
                 qtype,
                 modules_to_not_convert,
                 current_key_name,
-                convert_shape_only,
             )
     return model, has_been_replaced


-def ggml_convert_quant(model, qtype, optimize_model=True, convert_shape_only=False):
+def ggml_convert_quant(model, qtype, optimize_model=True, device="cpu"):
     modules_to_not_convert = []  # ["lm_head"]
     model, has_been_replaced = _replace_with_quant_linear(
-        model, qtype, modules_to_not_convert, None, convert_shape_only=convert_shape_only
+        model, qtype, modules_to_not_convert, None
     )
     if not has_been_replaced:
         warnings.warn(
@@ -123,8 +106,11 @@ def ggml_convert_quant(model, qtype, optimize_model=True, convert_shape_only=Fal
             "instead of Linear layers. Please double check your model architecture, or submit "
             "an issue on github if you think this is a bug."
         )
-    else:
+    elif device == "cpu":
         model.to(torch.float32)
+    elif device == "meta":
+        # Do nothing here since the weights are empty.
+        pass

     if optimize_model:
         model = optimize(model)
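A minimal editorial sketch (not part of the patch) of the device-aware behaviour this change relies on: module.weight.data.device.type distinguishes real CPU weights from empty "meta" weights created under lazy initialization, so the converter can keep each replacement parameter on the device it came from. Only plain torch is used and the names below are illustrative.

import torch
import torch.nn as nn

cpu_linear = nn.Linear(64, 64)                  # real storage on the CPU
meta_linear = nn.Linear(64, 64, device="meta")  # shape/dtype only, no storage

for module in (cpu_linear, meta_linear):
    device_type = module.weight.data.device.type  # "cpu" or "meta"
    if device_type == "meta":
        # nothing to quantize yet; keep the placeholder on the meta device
        print("meta: skip quantization, weights are empty")
    else:
        # CPU path: the real weights exist and can be quantized in place
        print("cpu: quantize the weights now")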
diff --git a/python/llm/src/bigdl/llm/transformers/linear_quant.py b/python/llm/src/bigdl/llm/transformers/linear_quant.py
index cb0d6cbc..f7b4bab1 100644
--- a/python/llm/src/bigdl/llm/transformers/linear_quant.py
+++ b/python/llm/src/bigdl/llm/transformers/linear_quant.py
@@ -59,7 +59,7 @@ TORCH_LINEAR_THRESHOLD = 96
 SYM_INT4 = ggml_tensor_qtype["sym_int4"]


-def ggml_convert_quant(tensor: torch.Tensor, qtype: int, convert_shape_only=False):
+def ggml_convert_quant(tensor: torch.Tensor, qtype: int, device=None):
     QK = ggml.ggml_qk_size(qtype)
     block_size_in_bytes = ggml.ggml_type_size(qtype)

@@ -75,12 +75,12 @@ def ggml_convert_quant(tensor: torch.Tensor, qtype: int, convert_shape_only=Fals
                       "Last dim of input tensor must be multiple of 64")

     dst_size = (n // QK) * block_size_in_bytes
-    dst_tensor = torch.empty(dst_size, dtype=torch.uint8)
-    dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
+    dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
+                             device=device)

-    hist = (ctypes.c_int64 * 16)()
-
-    if not convert_shape_only:
+    if device != 'meta':
+        dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
+        hist = (ctypes.c_int64 * 16)()
         ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)

     return dst_tensor
@@ -98,14 +98,16 @@ def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
     return dst_tensor


-class ParamsQuant(torch.nn.Parameter):
+# Rename to FP4Params to trigger initializing
+# the params layer with all parameters on the CPU
+# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
+class FP4Params(torch.nn.Parameter):
     def __new__(cls,
                 data=None,
-                requires_grad=True,
+                requires_grad=False,
                 old_data=None,
                 quantized=False,
                 _shape=None,
-                convert_shape_only=False,
                 qtype=None):
         if data is None:
             data = torch.empty(0)
@@ -114,16 +116,14 @@ class ParamsQuant(torch.nn.Parameter):
         self.data = data
         self.quantized = quantized
         self._shape = _shape
-        self.convert_shape_only = convert_shape_only
         self.qtype = qtype
         return self

-    def quantize(self, device):
+    def quantize(self, device=None):
         if not self.quantized:
             w = self.data.contiguous().float()
-            # self.old_data = self.data
             w_quantized = ggml_convert_quant(w, self.qtype,
-                                             convert_shape_only=self.convert_shape_only)
+                                             device=device)
             self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
@@ -147,28 +147,29 @@ class ParamsQuant(torch.nn.Parameter):

     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
-
         if (device is not None and device.type == "cpu" and self.data.device.type == "cpu"):
-            return self.quantize(device)
+            return self.quantize(device.type)
+        elif device is not None and device.type == "meta" and self.data.device.type == "meta":
+            return self.quantize(device.type)
         elif (device is not None and device.type == "xpu" and self.data.device.type == "cpu"):
             # enter xpu logic, compile linear_int4 extension at first time
             q_tensor = self.quantize(device)  # tensor is cpu now
-            new_param = ParamsQuant(super().to(device=device,
-                                               dtype=dtype,
-                                               non_blocking=non_blocking),
-                                    requires_grad=self.requires_grad,
-                                    quantized=self.quantized,
-                                    _shape=self._shape,
-                                    qtype=self.qtype)
+            new_param = FP4Params(super().to(device=device,
+                                             dtype=dtype,
+                                             non_blocking=non_blocking),
+                                  requires_grad=self.requires_grad,
+                                  quantized=self.quantized,
+                                  _shape=self._shape,
+                                  qtype=self.qtype)
             return new_param
         else:
-            new_param = ParamsQuant(super().to(device=device,
-                                               dtype=dtype,
-                                               non_blocking=non_blocking),
-                                    requires_grad=self.requires_grad,
-                                    quantized=self.quantized,
-                                    _shape=self._shape,
-                                    qtype=self.qtype)
+            new_param = FP4Params(super().to(device=device,
+                                             dtype=dtype,
+                                             non_blocking=non_blocking),
+                                  requires_grad=self.requires_grad,
+                                  quantized=self.quantized,
+                                  _shape=self._shape,
+                                  qtype=self.qtype)
             return new_param


@@ -213,9 +214,10 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
 class LinearQuant(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True):
         super().__init__(input_features, output_features, bias)
-        self.weight = ParamsQuant(self.weight.data, requires_grad=False,
-                                  old_data=self.weight.data,
-                                  quantized=False, _shape=None, qtype=qtype)
+        self.weight = FP4Params(self.weight.data,
+                                requires_grad=False,
+                                old_data=self.weight.data,
+                                quantized=False, _shape=None, qtype=qtype)
         self.in_len = input_features
         self.out_len = output_features
         self.weight_shape = (self.out_len, self.in_len)
@@ -223,7 +225,6 @@ class LinearQuant(nn.Linear):
         self.qtype = qtype

     def forward(self, x: torch.Tensor):
-        # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
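A small editorial sketch of the sizing logic that the new device != 'meta' branch preserves: the destination buffer is always allocated with the quantized size, but the ggml quantization kernel only runs when the source weights actually have storage. The qk and block_size_in_bytes values below are placeholders, not the library's real constants.

import torch

def quantized_dst_tensor(n_elements, qk=64, block_size_in_bytes=36, device=None):
    # qk / block_size_in_bytes stand in for ggml.ggml_qk_size(qtype) and
    # ggml.ggml_type_size(qtype); the real values depend on the qtype
    assert n_elements % qk == 0, "last dim must be a multiple of the block size"
    dst_size = (n_elements // qk) * block_size_in_bytes
    dst_tensor = torch.empty(dst_size, dtype=torch.uint8, device=device)
    if device != "meta":
        # real path: ggml.ggml_quantize_tensor(...) would fill dst_tensor here
        dst_tensor.zero_()
    return dst_tensor

print(quantized_dst_tensor(4096).shape)                 # cpu: allocated and filled
print(quantized_dst_tensor(4096, device="meta").shape)  # meta: shape only, no data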
+ load_keys = {"all_checkpoint_keys": list(self.state_dict().keys())} + with open(os.path.join(args[0], "load_keys.json"), "w") as json_file: + json.dump(load_keys, json_file) class _BaseAutoModelClass: @@ -106,11 +111,44 @@ class _BaseAutoModelClass: @classmethod def load_low_bit(cls, - *args, + pretrained_model_name_or_path, + *model_args, **kwargs): - # Read bigdl_transformers_low_bit from config.json - pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \ - if len(args) == 0 else args[0] + from transformers.modeling_utils import no_init_weights, get_state_dict_dtype + from transformers.dynamic_module_utils import resolve_trust_remote_code, \ + get_class_from_dynamic_module + from transformers.models.auto.configuration_auto import AutoConfig + from transformers.utils.generic import ContextManagers + from transformers.generation.configuration_utils import GenerationConfig + from transformers.models.auto.auto_factory import _get_model_class + from accelerate.big_modeling import init_empty_weights + from .convert import ggml_convert_quant + import copy + import os + + # Autofactory + trust_remote_code = kwargs.pop("trust_remote_code", None) + kwargs_orig = copy.deepcopy(kwargs) + + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + return_unused_kwargs=True, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + # if torch_dtype=auto was passed here, ensure to pass it on + if kwargs_orig.get("torch_dtype", None) == "auto": + kwargs["torch_dtype"] = "auto" + + # Maybe needed when extract_local_archive_file + subfolder = kwargs.get("subfolder", "") + variant = kwargs.get("variant", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + torch_dtype = kwargs.pop("torch_dtype", "auto") + sharded_metadata = None + config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path) bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False) @@ -123,89 +161,130 @@ class _BaseAutoModelClass: f"Unknown bigdl_transformers_low_bit value: {bigdl_transformers_low_bit}," f" expected: sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.") - # Speed up when loading model - kwargs["low_cpu_mem_usage"] = True - - # set default torch_dtype='auto' - kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto') - # set default optimize_model=True optimize_model = kwargs.pop("optimize_model", True) qtype = ggml_tensor_qtype[bigdl_transformers_low_bit] - # Note that the int4 linear layers cannot currently - # be recorded in huggingface Pretrained Model or AutoConfig, - # and huggingface transformers cls.HF_Model.from_pretrained - # could only restore the model in the original format, - # which is not quantized. we can Initialize original model first, - # convert the model to quantized int4 format later, and then load the quantized model. 
- # Avoid KeyError - kwargs["ignore_mismatched_sizes"] = True + has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map + has_local_code = type(config) in cls.HF_Model._model_mapping.keys() + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code + ) + if has_remote_code and trust_remote_code: + class_ref = config.auto_map[cls.HF_Model.__name__] + model_class = get_class_from_dynamic_module( + class_ref, pretrained_model_name_or_path, **kwargs + ) + if os.path.isdir(pretrained_model_name_or_path): + model_class.register_for_auto_class(cls.HF_Model.__name__) + else: + cls.HF_Model.register(config.__class__, model_class, exist_ok=True) + elif type(config) in cls.HF_Model._model_mapping.keys(): + model_class = _get_model_class(config, cls.HF_Model._model_mapping) - # Maybe needed when extract_local_archive_file - subfolder = kwargs.get("subfolder", "") - variant = kwargs.get("variant", None) - - from .convert import ggml_convert_quant - - with MuteHFLogger(logger=transformers.modeling_utils.logger): - model = cls.HF_Model.from_pretrained(*args, **kwargs) - - # add save_low_bit to pretrained model dynamically - import types - model.save_low_bit = types.MethodType(save_low_bit, model) - - # We forcefully modify the model's definition - # and the tensor shape of int4 weights without quantization. - model = ggml_convert_quant(model, qtype, optimize_model, convert_shape_only=True) - # Load the quantized model at last. resolved_archive_file, is_sharded = extract_local_archive_file( pretrained_model_name_or_path, subfolder, variant) + if is_sharded: resolved_archive_file, sharded_metadata = \ get_local_shard_files(pretrained_model_name_or_path, resolved_archive_file, subfolder=subfolder) - start_prefix = "" - prefix = model.base_model_prefix - loaded_keys = [fix_key(key) for key in sharded_metadata["all_checkpoint_keys"]] - if len(prefix) > 0: - has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) - else: - has_prefix_module = False - model_cls = type(model) - if len(model_cls.base_model_prefix) > 0 and \ - not hasattr(model, model_cls.base_model_prefix) and \ - has_prefix_module: - start_prefix = model_cls.base_model_prefix + "." - from transformers.modeling_utils import _load_state_dict_into_model - error_msgs = [] - for shard_file in resolved_archive_file: - state_dict = load_state_dict(shard_file) - error_msgs += _load_state_dict_into_model(model, state_dict, start_prefix) - # force memory release - del state_dict - gc.collect() + # set dtype to instantiate the model under: + # 1. If torch_dtype is not None, we use that dtype + # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, + # by checking its first weights entry that is of a floating type + # - we assume all floating dtype weights are of the same dtype + # we also may have config.torch_dtype available, but we won't rely on it till v5 + dtype_orig = None - if len(error_msgs) > 0: - error_msg = "\n\t".join(error_msgs) - if "size mismatch" in error_msg: - error_msg += ( - "\n\tYou may consider adding `ignore_mismatched_sizes=True`" - " in the model `from_pretrained` method." 
- ) - invalidInputError(False, "Error(s) in loading state_dict" - f"for {model.__class__.__name__}:\n\t{error_msg}") + if torch_dtype is not None: + if isinstance(torch_dtype, str): + if torch_dtype == "auto": + if hasattr(config, "torch_dtype") and config.torch_dtype is not None: + torch_dtype = config.torch_dtype + else: + if is_sharded and "dtype" in sharded_metadata: + torch_dtype = sharded_metadata["dtype"] + else: + one_state_dict = load_state_dict(resolved_archive_file[0]) + torch_dtype = get_state_dict_dtype(one_state_dict) + del one_state_dict # free CPU memory + else: + invalidInputError(False, + f'`torch_dtype` can be either `torch.dtype` or `"auto"`,' + 'but received {torch_dtype}') + dtype_orig = model_class._set_default_torch_dtype(torch_dtype) + + # Pretrained Model + _fast_init = kwargs.pop("_fast_init", True) + init_contexts = [no_init_weights(_enable=_fast_init)] + init_contexts.append(init_empty_weights()) + + with ContextManagers(init_contexts): + model = model_class(config, *model_args, **kwargs) + + model = ggml_convert_quant(model, qtype, optimize_model, device="meta") + + if is_sharded: + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] else: - state_dict = load_state_dict(resolved_archive_file) - load(model, state_dict) - del state_dict + import os + import json + with open(os.path.join(pretrained_model_name_or_path, + "load_keys.json"), "r") as json_file: + loaded_data = json.load(json_file) + loaded_state_dict_keys = loaded_data["all_checkpoint_keys"] + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = model_class._load_pretrained_model( + model, + None, + loaded_state_dict_keys, # XXX: rename? + resolved_archive_file, + pretrained_model_name_or_path, + sharded_metadata=sharded_metadata, + _fast_init=_fast_init, + low_cpu_mem_usage=True, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + keep_in_fp32_modules=[], + ) + + # make sure token embedding weights are still tied if needed + model.tie_weights() + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder=subfolder, + **kwargs, + ) + except (OSError, TypeError): + pass + for param in model.parameters(): + param.requires_grad_(False) return model
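Taken together, the three files give a save/load round trip that never materializes the full-precision model during reload. A usage sketch follows; the model id and output directory are placeholders, and it assumes the existing from_pretrained(load_in_4bit=True) entry point mentioned in save_low_bit's error message.

from bigdl.llm.transformers import AutoModelForCausalLM

# first run: quantize from the original checkpoint and persist the low-bit
# weights; save_low_bit() also writes load_keys.json next to the shards
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_4bit=True)
model.save_low_bit("./llama-2-7b-4bit")

# later runs: load_low_bit() builds the model skeleton on the meta device,
# converts it to quantized LinearQuant layers, then streams the saved
# low-bit shards in, so the fp16/fp32 weights are never held in memory
model = AutoModelForCausalLM.load_low_bit("./llama-2-7b-4bit")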