diff --git a/python/llm/src/bigdl/llm/optimize.py b/python/llm/src/bigdl/llm/optimize.py
index efff4266..43abef5f 100644
--- a/python/llm/src/bigdl/llm/optimize.py
+++ b/python/llm/src/bigdl/llm/optimize.py
@@ -14,11 +14,56 @@
 # limitations under the License.
 #
+import torch
+import os
+import json
 from .transformers import ggml_convert_low_bit
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
+
+# Simulate the Hugging Face checkpoint format
+PYTORCH_MODEL_NAME = "pytorch_model.bin"
+CONFIG_NAME = "bigdl_config.json"
+
+
+def _save_low_bit(self, save_dir, *args, **kwargs):
+    invalidInputError(self._bigdl_config.get("bigdl_transformers_low_bit", False),
+                      "Detected that this model is not a low-bit model; please use"
+                      " `optimize_model` with a low_bit value to obtain a low-bit"
+                      " model first.")
+    os.makedirs(save_dir, exist_ok=True)
+    model_path = os.path.join(save_dir, PYTORCH_MODEL_NAME)
+    torch.save(self.state_dict(), model_path, *args, **kwargs)
+    with open(os.path.join(save_dir, CONFIG_NAME), "w") as json_file:
+        json.dump(self._bigdl_config, json_file)
+
+
+def load_low_bit(model, model_path):
+    invalidInputError(isinstance(model, torch.nn.Module),
+                      "model should be an instance of `torch.nn.Module`.")
+    invalidInputError(os.path.isdir(model_path),
+                      "model_path should be a valid directory path.")
+    invalidInputError(os.path.isfile(os.path.join(model_path, CONFIG_NAME)),
+                      "bigdl_config.json should be under your model directory;"
+                      " please check your input path.")
+    with open(os.path.join(model_path, CONFIG_NAME), 'r') as f:
+        _config = json.load(f)
+
+    low_bit = _config.get("bigdl_transformers_low_bit", None)
+    invalidInputError(low_bit,
+                      "Detected that this model is not a low-bit model. Please use"
+                      " `optimize_model` with low_bit to get a low-bit model, and"
+                      " serialize it with save_low_bit first.")
+
+    if low_bit:
+        qtype = ggml_tensor_qtype[low_bit]
+        model = ggml_convert_low_bit(model, qtype=qtype, convert_shape_only=True)
+
+    state_dict = torch.load(os.path.join(model_path, PYTORCH_MODEL_NAME))
+    model.load_state_dict(state_dict=state_dict)
+    return model
+
+
 def optimize_model(model, low_bit='sym_int4', optimize_llm=True):
     """
     A method to optimize any pytorch models.
 
@@ -34,4 +79,10 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True):
                       f"Unknown load_in_low_bit value: {low_bit}, expected:"
                       f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
     qtype = ggml_tensor_qtype[low_bit]
-    return ggml_convert_low_bit(model, qtype=qtype, optimize_model=optimize_llm)
+    model = ggml_convert_low_bit(model, qtype=qtype, optimize_model=optimize_llm)
+    # dynamically attach save_low_bit to the optimized model
+    import types
+    model._bigdl_config = dict()
+    model._bigdl_config["bigdl_transformers_low_bit"] = low_bit
+    model.save_low_bit = types.MethodType(_save_low_bit, model)
+    return model
diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 00156300..daf4588b 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -45,7 +45,7 @@ from .utils import logger
 
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
-                                 current_key_name=None):
+                                 current_key_name=None, convert_shape_only=False):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params
     has_been_replaced = False
 
@@ -70,6 +70,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                              requires_grad=False,
                                              quantized=False,
                                              _shape=None,
+                                             convert_shape_only=convert_shape_only,
                                              qtype=qtype).to(device_type)
                     new_linear._parameters['weight'] = paramsLowBit
 
@@ -91,15 +92,18 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 qtype,
                 modules_to_not_convert,
                 current_key_name,
+                convert_shape_only,
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
 
 
-def ggml_convert_low_bit(model, qtype, optimize_model=True, device="cpu"):
+def ggml_convert_low_bit(model, qtype, optimize_model=True,
+                         convert_shape_only=False, device="cpu"):
     modules_to_not_convert = []  # ["lm_head"]
     model, has_been_replaced = _replace_with_low_bit_linear(
-        model, qtype, modules_to_not_convert, None
+        model, qtype, modules_to_not_convert,
+        None, convert_shape_only,
     )
     if not has_been_replaced:
         warnings.warn(
diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index c5f6f790..69bab76e 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -64,7 +64,8 @@ SYM_INT8 = ggml_tensor_qtype["sym_int8"]
 NF4 = ggml_tensor_qtype["nf4"]
 
 
-def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, device=None):
+def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
+                       device=None, convert_shape_only=False):
     QK = ggml.ggml_qk_size(qtype)
     block_size_in_bytes = ggml.ggml_type_size(qtype)
 
@@ -83,7 +84,7 @@
     dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
                              device=device)
 
-    if device != 'meta':
+    if not convert_shape_only and device != 'meta':
         dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
         hist = (ctypes.c_int64 * 16)()
         ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
@@ -162,6 +163,7 @@ class FP4Params(torch.nn.Parameter):
                 requires_grad=False,
                 quantized=False,
                 _shape=None,
+                convert_shape_only=False,
                 qtype=None):
         if data is None:
             data = torch.empty(0)
@@ -171,13 +173,15 @@
         self.quantized = quantized
         self._shape = _shape
         self.qtype = qtype
+        self.convert_shape_only = convert_shape_only
         return self
 
     def quantize(self, device=None):
         if not self.quantized:
             w = self.data.contiguous().float()
             w_quantized = ggml_convert_qtype(w, self.qtype,
-                                             device=device)
+                                             device=device,
+                                             convert_shape_only=self.convert_shape_only)
             self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
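
Usage sketch (illustrative, not part of the patch): the round trip below shows how the new save_low_bit / load_low_bit pair is intended to be used, assuming a Hugging Face transformers model; the model id and save directory are placeholders.

    import torch
    from transformers import AutoModelForCausalLM
    from bigdl.llm.optimize import optimize_model, load_low_bit

    MODEL_ID = "openlm-research/open_llama_3b"  # placeholder model id
    SAVE_DIR = "./open-llama-3b-sym-int4"       # placeholder save directory

    # First run: quantize to 4-bit, then persist the low-bit state dict and
    # the bigdl_config.json marker in the simulated Hugging Face layout.
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
    model = optimize_model(model, low_bit='sym_int4')
    model.save_low_bit(SAVE_DIR)

    # Later run: rebuild the architecture, then let load_low_bit swap in the
    # low-bit Linear modules with convert_shape_only=True (allocating the
    # quantized buffers without re-running ggml quantization) and fill them
    # from the saved state dict.
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
    model = load_low_bit(model, SAVE_DIR)

Note that load_low_bit still requires the caller to materialize the original architecture first; convert_shape_only skips only the quantization math in ggml_convert_qtype, not the initial weight load.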