diff --git a/python/llm/src/bigdl/llm/ggml/quantize.py b/python/llm/src/bigdl/llm/ggml/quantize.py
index 7023a4bd..579ee913 100644
--- a/python/llm/src/bigdl/llm/ggml/quantize.py
+++ b/python/llm/src/bigdl/llm/ggml/quantize.py
@@ -31,7 +31,8 @@ ggml_tensor_qtype = {"sym_int4": 2,   # q4_0 in ggml
                      "asym_int5": 7,  # q5_1 in ggml
                      "sym_int8": 8,   # q8_0 in ggml
                      "nf4": 10,
-                     "nf3": 11}
+                     "nf3": 11,
+                     "fp16": 12}
 
 _llama_quantize_type = {"q4_0": 2,
                         "q4_1": 3,
@@ -71,7 +72,7 @@ def quantize(input_path: str, output_path: str,
     :param dtype: Quantization method which differs in the resulting model disk size and
            inference speed. Defalut to `q4_0`. Difference model family may support different types,
            now the supported list is:
-           llama : "q4_0", "q4_1", "q4_2"
+           llama : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
            bloom : "q4_0", "q4_1"
            gptneox : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
            starcoder : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index b0bc581d..e0c76233 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -41,12 +41,13 @@ from accelerate import init_empty_weights
 import warnings
 import transformers
 import importlib
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from .utils import logger
 
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False):
-    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params
+    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
     has_been_replaced = False
 
     for name, module in model.named_children():
@@ -57,33 +58,55 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             # Check if the current key is not in the `modules_to_not_convert`
             if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                 with init_empty_weights():
-                    new_linear = LowBitLinear(
-                        module.in_features,
-                        module.out_features,
-                        qtype,
-                        module.bias is not None,
-                    )
+                    new_linear = None
+                    if qtype != ggml_tensor_qtype["fp16"]:
+                        new_linear = LowBitLinear(
+                            module.in_features,
+                            module.out_features,
+                            qtype,
+                            module.bias is not None,
+                        )
 
-                    device_type = module.weight.data.device.type
-                    # Copy the weights
-                    paramsLowBit = FP4Params(data=module.weight.data,
-                                             requires_grad=False,
-                                             quantized=False,
-                                             _shape=None,
-                                             convert_shape_only=convert_shape_only,
-                                             qtype=qtype).to(device_type)
-                    new_linear._parameters['weight'] = paramsLowBit
+                        device_type = module.weight.data.device.type
+                        # Copy the weights
+                        paramsLowBit = FP4Params(data=module.weight.data,
+                                                 requires_grad=False,
+                                                 quantized=False,
+                                                 _shape=None,
+                                                 convert_shape_only=convert_shape_only,
+                                                 qtype=qtype).to(device_type)
+                        new_linear._parameters['weight'] = paramsLowBit
+                    else:
+                        # only support two sizes now
+                        # may generalize to other sizes
+                        if module.in_features in [4096, 11008]:
+                            # esimd fp16 path
+                            new_linear = FP16Linear(
+                                module.in_features,
+                                module.out_features,
+                                qtype,
+                                module.bias is not None,
+                            )
+                            device_type = module.weight.data.device.type
 
-                    if module.bias is not None:
-                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                            .to(device_type)
+                            # convert weight to (m//16, n, 16) layout for the esimd fp16 kernel
+                            m, n = module.weight.data.shape
+                            trans_weight = module.weight.data.reshape(m//16, 16, n)
+                            trans_weight = trans_weight.transpose(1, 2).contiguous()
+                            new_linear._parameters['weight'] = nn.Parameter(trans_weight)
 
-                    model._modules[name] = new_linear
-                    has_been_replaced = True
-                    # Force requires grad to False to avoid unexpected errors
-                    model._modules[name].requires_grad_(False)
+                    # fp16 may generalize to other sizes later
+                    if new_linear is not None:
+                        if module.bias is not None:
+                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                                .to(device_type)
 
-                    module.weight = None
+                        model._modules[name] = new_linear
+                        has_been_replaced = True
+                        # Force requires grad to False to avoid unexpected errors
+                        model._modules[name].requires_grad_(False)
+
+                        module.weight = None
 
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 931a118f..8500f76a 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -378,3 +378,53 @@ class LowBitLinear(nn.Linear):
             result += self.bias
 
         return result.to(x.dtype)
+
+
+class FP16Linear(nn.Linear):
+    def __init__(self, input_features, output_features, qtype, bias=True,
+                 convert_to_half=True):
+        super().__init__(input_features, output_features, bias)
+        self.in_len = input_features
+        self.out_len = output_features
+        self.weight_shape = (self.out_len, self.in_len)
+        self.weight_length = self.out_len * self.in_len
+        self.qtype = qtype
+        self.convert_to_half = convert_to_half
+
+    def forward(self, x: torch.Tensor):
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        x_shape = x.shape
+        x_2d = x.view(-1, x_shape[-1])
+
+        x0 = self.weight.data
+        # the esimd fp16 path only works on Intel GPU (xpu)
+        invalidInputError(x0.device.type == "xpu",
+                          "FP16 only works for GPU")
+        try:
+            import intel_extension_for_pytorch
+            import linear_fp16_esimd
+        except ModuleNotFoundError:
+            invalidInputError(False,
+                              "Please `pip install bigdl_core_xe` first.")
+
+        if not x_2d.is_contiguous():
+            x_2d = x_2d.contiguous()
+
+        if x_2d.shape[0] > 1:
+            # first token or batch size > 1: restore the original weight layout
+            original_weight = self.weight.data.transpose(1, 2)
+            original_weight = original_weight.reshape(self.out_len, self.in_len)
+            result = F.linear(x_2d, original_weight.contiguous())
+            del original_weight
+        else:
+            # subsequent tokens: use the esimd optimization
+            result = linear_fp16_esimd.forward(x_2d, self.weight.data)
+
+        new_shape = x_shape[:-1] + (self.out_len,)
+        result = result.view(new_shape)
+        if self.bias is not None:
+            result += self.bias
+
+        return result.to(x.dtype)
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 8f926517..51b2f656 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -60,7 +60,7 @@ class _BaseAutoModelClass:
         :param load_in_4bit: boolean value, True means load linear's weight to symmetric int 4.
                              Default to be False.
         :param load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5
-                                or sym_int8. sym_int4 means symmetric int 4, asym_int4 means
+                                , sym_int8 or fp16. sym_int4 means symmetric int 4, asym_int4 means
                                 asymmetric int 4, etc. Relevant low bit optimizations will be applied
                                 to the model.
         :param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
@@ -104,8 +104,9 @@ class _BaseAutoModelClass:
         from .convert import ggml_convert_low_bit
         invalidInputError(q_k in ggml_tensor_qtype,
                           f"Unknown load_in_low_bit value: {q_k}, expected:"
-                          f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
+                          f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8 or fp16.")
         qtype = ggml_tensor_qtype[q_k]
+
         # In case it needs a second try,
         # `from_pretrained`` may pop items out in dict
         # and lead to args missing.
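
Reviewer note (not part of the patch): the fp16 branch in convert.py packs each eligible weight into a (m//16, n, 16) tensor, and FP16Linear.forward undoes that packing with transpose(1, 2) + reshape before falling back to F.linear for prompts or batched input. A standalone sketch of that round trip, with sizes chosen from the supported 4096/11008 list purely for illustration:

import torch

# Pack the weight the way _replace_with_low_bit_linear does for the esimd fp16 path,
# then recover it the way FP16Linear.forward does in its F.linear fallback branch.
m, n = 4096, 4096                                  # out_features, in_features (illustrative)
weight = torch.randn(m, n, dtype=torch.float16)

packed = weight.reshape(m // 16, 16, n).transpose(1, 2).contiguous()  # shape (m//16, n, 16)
recovered = packed.transpose(1, 2).reshape(m, n)                      # back to (m, n)

assert torch.equal(weight, recovered)              # the packing is lossless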
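
A minimal usage sketch for the new option (also not part of the patch). The checkpoint path and prompt are placeholders, the torch_dtype handling is an assumption, and it presumes an Intel GPU environment with intel_extension_for_pytorch and bigdl_core_xe installed plus a LLaMA-style model whose projection layers match the supported in_features sizes:

import torch
import intel_extension_for_pytorch  # noqa: F401  (registers the "xpu" device)
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "path/to/llama-7b"              # placeholder checkpoint path
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit="fp16",
                                             torch_dtype=torch.float16,  # assumption: keep weights in half
                                             trust_remote_code=True)
model = model.to("xpu")                      # FP16Linear only runs on xpu

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids.to("xpu")
with torch.inference_mode():
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))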