diff --git a/python/llm/README.md b/python/llm/README.md
index 0a1edada..363ea637 100644
--- a/python/llm/README.md
+++ b/python/llm/README.md
@@ -100,7 +100,12 @@ You may run the models using `transformers`-style API in `bigdl-llm`.
       output = tokenizer.batch_decode(output_ids)
   ```
 
-  See the complete example [here](example/transformers/transformers_int4/transformers_int4_pipeline.py).
+  See the complete example [here](example/transformers/transformers_int4/transformers_int4_pipeline.py).
+
+  Note: you can use the `load_in_low_bit` parameter (in place of `load_in_4bit`) to select other quantization precisions: `q4_0` and `q4_1` are INT4 quantization, `q5_0` and `q5_1` are INT5 quantization, and `q8_0` is INT8 quantization. For example:
+  ```python
+  model = AutoModelForCausalLM.from_pretrained('/path/to/model/', load_in_low_bit="q5_0")
+  ```
+
 
-
   ##### Using native INT4 format
diff --git a/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py b/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
index 956c3bfa..f9ad0917 100644
--- a/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
+++ b/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
@@ -955,28 +955,49 @@ _lib.llama_print_system_info.restype = c_char_p
 # GGML API
 
 
-def ggml_quantize_q4_0(
+def ggml_quantize_tensor(
     src,  # type: ctypes.Array[ctypes.c_float] # type: ignore
     dst: ctypes.c_void_p,
+    qtype: ctypes.c_int,
     n: ctypes.c_int,
     k: ctypes.c_int,
     hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
 ) -> int:
-    return _lib.ggml_quantize_q4_0(src, dst, n, k, hist)
+    return _lib.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
 
 
-_lib.ggml_quantize_q4_0.argtypes = [
+_lib.ggml_quantize_tensor.argtypes = [
     ctypes.POINTER(ctypes.c_float),
     ctypes.c_void_p,
     ctypes.c_int,
     ctypes.c_int,
+    ctypes.c_int,
     ctypes.POINTER(ctypes.c_int64),
 ]
-_lib.ggml_quantize_q4_0.restype = ctypes.c_size_t
+_lib.ggml_quantize_tensor.restype = ctypes.c_size_t
+
+
+def ggml_type_size(qtype: ctypes.c_int) -> int:
+    return _lib.ggml_type_size(qtype)
+
+_lib.ggml_type_size.argtypes = [
+    ctypes.c_int,
+]
+_lib.ggml_type_size.restype = ctypes.c_int
+
+
+def ggml_qk_size(qtype: ctypes.c_int) -> int:
+    return _lib.ggml_qk_size(qtype)
+
+_lib.ggml_qk_size.argtypes = [
+    ctypes.c_int,
+]
+_lib.ggml_qk_size.restype = ctypes.c_int
 
 
 def ggml_compute_forward_mul_mat_q_fp32(src_0_ne,  # type: ctypes.Array[ctypes.c_int64]
                                         src_0_data,  # type: ctypes.c_void_p
+                                        src_0_qtype,  # type: int
                                         src_1_ne,  # type: ctypes.Array[ctypes.c_int64]
                                         src_1_data,  # type: ctypes.c_void_p
                                         result,  # type: ctypes.c_void_p
@@ -991,6 +1012,7 @@ def ggml_compute_forward_mul_mat_q_fp32(src_0_ne,  # type: ctypes.Array[ctypes.c
 
     return _lib.ggml_compute_forward_mul_mat_q_fp32(src_0_ne,
                                                     src_0_data,
+                                                    src_0_qtype,
                                                     src_1_ne,
                                                     src_1_data,
                                                     result)
@@ -999,6 +1021,7 @@ def ggml_compute_forward_mul_mat_q_fp32(src_0_ne,  # type: ctypes.Array[ctypes.c
 _lib.ggml_compute_forward_mul_mat_q_fp32.argtypes = [
     ctypes.POINTER(ctypes.c_int64),
     ctypes.c_void_p,
+    ctypes.c_int,
     ctypes.POINTER(ctypes.c_int64),
     ctypes.c_void_p,
     ctypes.c_void_p
diff --git a/python/llm/src/bigdl/llm/ggml/quantize.py b/python/llm/src/bigdl/llm/ggml/quantize.py
index 952b48e1..c34ef522 100644
--- a/python/llm/src/bigdl/llm/ggml/quantize.py
+++ b/python/llm/src/bigdl/llm/ggml/quantize.py
@@ -24,6 +24,13 @@ from pathlib import Path
 dirname, _ = os.path.split(os.path.abspath(__file__))
 libs_dirname = os.path.dirname(dirname)
 
+# ggml quantized tensor types; note these differ from the file-level quantization types (_llama_quantize_type) below
+ggml_tensor_qtype = {"q4_0": 2,
+                     "q4_1": 3,
+                     "q5_0": 6,
+                     "q5_1": 7,
+                     "q8_0": 8}
"q8_0": 8} + _llama_quantize_type = {"q4_0": 2, "q4_1": 3, "q5_0": 8, diff --git a/python/llm/src/bigdl/llm/transformers/__init__.py b/python/llm/src/bigdl/llm/transformers/__init__.py index 67bd2474..a22eeaf1 100644 --- a/python/llm/src/bigdl/llm/transformers/__init__.py +++ b/python/llm/src/bigdl/llm/transformers/__init__.py @@ -14,6 +14,6 @@ # limitations under the License. # -from .convert import ggml_convert_int4 +from .convert import ggml_convert_quant from .model import AutoModelForCausalLM, AutoModel from .modelling_bigdl import BigdlNativeForCausalLM diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 7c45b506..642304ed 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -37,12 +37,12 @@ import torch import torch.nn as nn from accelerate import init_empty_weights -from bigdl.llm.transformers.linear_int4 import LinearInt4, ParamsInt4 +from bigdl.llm.transformers.linear_quant import LinearQuant, ParamsQuant import warnings -def _replace_with_int4_linear(model, modules_to_not_convert=None, - current_key_name=None, convert_shape_only=False): +def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None, + current_key_name=None, convert_shape_only=False): has_been_replaced = False for name, module in model.named_children(): if current_key_name is None: @@ -53,19 +53,22 @@ def _replace_with_int4_linear(model, modules_to_not_convert=None, if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): with init_empty_weights(): - new_linear = LinearInt4( + new_linear = LinearQuant( module.in_features, module.out_features, + qtype, module.bias is not None, ) # Copy the weights - paramsint4 = ParamsInt4(data=module.weight.data, - requires_grad=False, - quantized=False, - convert_shape_only=convert_shape_only, - _shape=None).to("cpu") - new_linear._parameters['weight'] = paramsint4 + paramsQuant = ParamsQuant(data=module.weight.data, + requires_grad=False, + quantized=False, + convert_shape_only=convert_shape_only, + _shape=None, + qtype=qtype).to("cpu") + new_linear._parameters['weight'] = paramsQuant + if module.bias is not None: new_linear._parameters['bias'] = nn.Parameter(module.bias.data).to("cpu") @@ -78,18 +81,19 @@ def _replace_with_int4_linear(model, modules_to_not_convert=None, # Remove the last key for recursion if len(list(module.children())) > 0: - _, has_been_replaced = _replace_with_int4_linear( + _, has_been_replaced = _replace_with_quant_linear( module, + qtype, modules_to_not_convert, current_key_name, ) return model, has_been_replaced -def ggml_convert_int4(model, convert_shape_only=False): +def ggml_convert_quant(model, qtype, convert_shape_only=False): modules_to_not_convert = [] # ["lm_head"] - model, has_been_replaced = _replace_with_int4_linear( - model, modules_to_not_convert, None, convert_shape_only=convert_shape_only + model, has_been_replaced = _replace_with_quant_linear( + model, qtype, modules_to_not_convert, None, convert_shape_only=convert_shape_only ) if not has_been_replaced: warnings.warn( diff --git a/python/llm/src/bigdl/llm/transformers/linear_int4.py b/python/llm/src/bigdl/llm/transformers/linear_quant.py similarity index 77% rename from python/llm/src/bigdl/llm/transformers/linear_int4.py rename to python/llm/src/bigdl/llm/transformers/linear_quant.py index 7608e766..7b19c224 100644 --- a/python/llm/src/bigdl/llm/transformers/linear_int4.py +++ 
@@ -55,12 +55,10 @@ import bigdl.llm.ggml.model.llama.llama_cpp as ggml
 import torch
 import ctypes
 
-QK = 64  # todo read this value from libllama.so
-scale_size_in_bytes = 4
-block_size_in_bytes = QK // 2 + scale_size_in_bytes
-
 
-def ggml_convert_int4(tensor: torch.Tensor, convert_shape_only=False):
+def ggml_convert_quant(tensor: torch.Tensor, qtype: int, convert_shape_only=False):
+    QK = ggml.ggml_qk_size(qtype)
+    block_size_in_bytes = ggml.ggml_type_size(qtype)
 
     invalidInputError(tensor.dtype == torch.float,
                       "Input tensor must be float32")
@@ -80,13 +78,19 @@ def ggml_convert_int4(tensor: torch.Tensor, convert_shape_only=False):
     hist = (ctypes.c_int64 * 16)()
 
     if not convert_shape_only:
-        ggml.ggml_quantize_q4_0(src, dst, n, k, hist)
+        ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
     return dst_tensor
 
 
-class ParamsInt4(torch.nn.Parameter):
-    def __new__(cls, data=None, requires_grad=True, old_data=None,
-                quantized=False, _shape=None, convert_shape_only=False):
+class ParamsQuant(torch.nn.Parameter):
+    def __new__(cls,
+                data=None,
+                requires_grad=True,
+                old_data=None,
+                quantized=False,
+                _shape=None,
+                convert_shape_only=False,
+                qtype=None):
         if data is None:
             data = torch.empty(0)
 
@@ -95,14 +99,16 @@ class ParamsInt4(torch.nn.Parameter):
         self.quantized = quantized
         self._shape = _shape
         self.convert_shape_only = convert_shape_only
+        self.qtype = qtype
         return self
 
     def quantize(self, device):
         if not self.quantized:
             w = self.data.contiguous().float()
             # self.old_data = self.data
-            w_4bit = ggml_convert_int4(w, convert_shape_only=self.convert_shape_only)
-            self.data = w_4bit
+            w_quantized = ggml_convert_quant(w, self.qtype,
+                                             convert_shape_only=self.convert_shape_only)
+            self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
         return self
@@ -129,17 +135,21 @@ class ParamsInt4(torch.nn.Parameter):
         if (device is not None and device.type == "cpu" and self.data.device.type == "cpu"):
             return self.quantize(device)
         else:
-            new_param = ParamsInt4(super().to(device=device,
-                                              dtype=dtype,
-                                              non_blocking=non_blocking),
-                                   requires_grad=self.requires_grad,
-                                   quantized=self.quantized,
-                                   _shape=self._shape)
+            new_param = ParamsQuant(super().to(device=device,
+                                               dtype=dtype,
+                                               non_blocking=non_blocking),
+                                    requires_grad=self.requires_grad,
+                                    quantized=self.quantized,
+                                    _shape=self._shape,
+                                    qtype=self.qtype)
             return new_param
 
 
-def ggml_matmul_src1_x_src0_t(src0: torch.Tensor, src1: torch.Tensor, src0_shape: torch.Size):
+def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
+                              src1: torch.Tensor,
+                              src0_shape: torch.Size,
+                              src0_qtype: int):
     if src1.dtype != torch.float32:
         src1 = src1.float()
@@ -165,6 +175,7 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor, src1: torch.Tensor, src0_shape
         # ctx=ctx_p,
         src_0_ne=src_0_ne,
         src_0_data=src_0_data,
+        src_0_qtype=src0_qtype,
         src_1_ne=src_1_ne,
         src_1_data=src_1_data,
         result=result_ptr,
@@ -173,15 +184,16 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor, src1: torch.Tensor, src0_shape
     return result_t
 
 
-class LinearInt4(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True):
+class LinearQuant(nn.Linear):
+    def __init__(self, input_features, output_features, qtype, bias=True):
         super().__init__(input_features, output_features, bias)
-        self.weight = ParamsInt4(self.weight.data, requires_grad=False,
-                                 old_data=self.weight.data,
-                                 quantized=False, _shape=None)
+        self.weight = ParamsQuant(self.weight.data, requires_grad=False,
+                                  old_data=self.weight.data,
+                                  quantized=False, _shape=None, qtype=qtype)
         self.in_len = input_features
         self.out_len = output_features
         self.weight_shape = (self.out_len, self.in_len)
+        self.qtype = qtype
 
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -193,7 +205,7 @@ class LinearInt4(nn.Linear):
 
         x0 = self.weight.data
 
-        result = ggml_matmul_src1_x_src0_t(x0, x, self.weight_shape)
+        result = ggml_matmul_src1_x_src0_t(x0, x, self.weight_shape, self.qtype)
 
         new_shape = x_shape[:-1] + (self.out_len,)
         result = result.view(new_shape)
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 3a50b9b9..6ebd4828 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -17,6 +17,8 @@
 import transformers
 from transformers.configuration_utils import PretrainedConfig
 from .utils import extract_local_archive_file, load_state_dict, load
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.utils.common import invalidInputError
 
 
 class _BaseAutoModelClass:
@@ -28,8 +30,17 @@ class _BaseAutoModelClass:
                         *args,
                         **kwargs):
         load_in_4bit = kwargs.pop("load_in_4bit", False)
+        qtype = 0
         if load_in_4bit:
             kwargs["low_cpu_mem_usage"] = True
+            qtype = ggml_tensor_qtype['q4_0']
+        load_in_low_bit = kwargs.pop("load_in_low_bit", "").lower()
+        if load_in_low_bit:
+            kwargs["low_cpu_mem_usage"] = True
+            invalidInputError(load_in_low_bit in ggml_tensor_qtype,
+                              f"Unknown load_in_low_bit value: {load_in_low_bit},"
+                              f" expected: q4_0, q4_1, q5_0, q5_1, q8_0.")
+            qtype = ggml_tensor_qtype[load_in_low_bit]
 
         subfolder = kwargs.get("subfolder", "")
         variant = kwargs.get("variant", None)
@@ -58,10 +69,10 @@ class _BaseAutoModelClass:
         # be recorded in AutoConfig,
         # and this operation is not included in the core Hugging Face infrastructure.
         if bigdl_transformers_int4:
-            from .convert import ggml_convert_int4
+            from .convert import ggml_convert_quant
             # We forcefully modify the model's definition
             # and the tensor shape of int4 weights without quantization.
-            model = ggml_convert_int4(model, convert_shape_only=True)
+            model = ggml_convert_quant(model, qtype, convert_shape_only=True)
             # Load the quantized model at last.
             archive_file = extract_local_archive_file(pretrained_model_name_or_path,
                                                       subfolder,
@@ -69,10 +80,10 @@ class _BaseAutoModelClass:
             state_dict = load_state_dict(archive_file)
             load(model, state_dict)
             del state_dict
-        elif load_in_4bit:
-            from .convert import ggml_convert_int4
+        elif qtype:
+            from .convert import ggml_convert_quant
             model = model.to("cpu")
-            model = ggml_convert_int4(model)
+            model = ggml_convert_quant(model, qtype)
             model.config.update({"bigdl_transformers_int4": True})
 
         return model
diff --git a/python/llm/test/convert/test_convert_model.py b/python/llm/test/convert/test_convert_model.py
index c8b029ce..3bb96b00 100644
--- a/python/llm/test/convert/test_convert_model.py
+++ b/python/llm/test/convert/test_convert_model.py
@@ -65,13 +65,22 @@ class TestConvertModel(TestCase):
         assert os.path.isfile(converted_model_path)
 
     def test_transformer_convert_llama(self):
-        model = AutoModelForCausalLM.from_pretrained(llama_model_path,
-                                                     load_in_4bit=True)
+        model = AutoModelForCausalLM.from_pretrained(llama_model_path, load_in_4bit=True)
         tempdir = tempfile.mkdtemp(dir=output_dir)
         model.save_pretrained(tempdir)
        model = AutoModelForCausalLM.from_pretrained(tempdir)
         assert model is not None
 
+    def test_transformer_convert_llama_q5(self):
+        model = AutoModelForCausalLM.from_pretrained(llama_model_path,
+                                                     load_in_low_bit="q5_0")
+        assert model is not None
+
+    def test_transformer_convert_llama_q8(self):
+        model = AutoModelForCausalLM.from_pretrained(llama_model_path,
+                                                     load_in_low_bit="q8_0")
+        assert model is not None
+
 
 if __name__ == '__main__':
     pytest.main([__file__])
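For reference, a minimal usage sketch of the `load_in_low_bit` option introduced above (the checkpoint path, tokenizer class, and prompt below are placeholders, not part of the change itself):

```python
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer

# Hypothetical local path to a LLaMA-family checkpoint; replace with your own.
model_path = '/path/to/llama/model'

# load_in_low_bit accepts q4_0, q4_1, q5_0, q5_1 or q8_0; q5_0 (INT5) is shown here.
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit="q5_0")
tokenizer = LlamaTokenizer.from_pretrained(model_path)

# Run a short generation with the quantized model.
input_ids = tokenizer.encode("Once upon a time,", return_tensors="pt")
output_ids = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.batch_decode(output_ids))
```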