diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 033ca934..3eca37de 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -191,10 +191,10 @@ def convert_gptq(module, awq=False, llm_awq=False):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
-                                 imatrix_data=None):
+                                 imatrix_data=None, embedding_qtype=None):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
-    from bigdl.llm.transformers.embedding import LLMEmbedding
+    from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
     has_been_replaced = False

     for name, module in model.named_children():
@@ -323,6 +323,32 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 sparse=module.sparse,
                 _weight=module.weight.data,
             )
+        elif type(module) == nn.Embedding and embedding_qtype is not None:
+            q_embedding = LowBitEmbedding(
+                num_embeddings=module.num_embeddings,
+                embedding_dim=module.embedding_dim,
+                padding_idx=module.padding_idx,
+                max_norm=module.max_norm,
+                norm_type=module.norm_type,
+                scale_grad_by_freq=module.scale_grad_by_freq,
+                sparse=module.sparse,
+                _weight=module.weight.data,
+                qtype=embedding_qtype,
+            )
+            device = module.weight.data.device
+            # Copy the weights
+            paramsLowBit = FP4Params(data=module.weight.data,
+                                     requires_grad=False,
+                                     quantized=False,
+                                     _shape=None,
+                                     convert_shape_only=convert_shape_only,
+                                     qtype=embedding_qtype,
+                                     in_features=module.embedding_dim).to(device)
+            q_embedding._parameters['weight'] = paramsLowBit
+            model._modules[name] = q_embedding
+            # Force requires grad to False to avoid unexpected errors
+            model._modules[name].requires_grad_(False)
+            module.weight = None

         # Remove the last key for recursion
         if len(list(module.children())) > 0:
@@ -334,7 +360,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 convert_shape_only,
                 cpu_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
-                imatrix_data=imatrix_data
+                imatrix_data=imatrix_data,
+                embedding_qtype=embedding_qtype
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -512,7 +539,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
                          modules_to_not_convert=None, cpu_embedding=False,
                          lightweight_bmm=False, torch_dtype="auto",
-                         imatrix_data=None):
+                         imatrix_data=None, embedding_qtype=None):
     logger.info(f"Converting the current model to "
                 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                 f"format......")
@@ -535,6 +562,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         model, qtype, modules_to_not_convert,
         None, convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
+        embedding_qtype=embedding_qtype
     )
     if not has_been_replaced:
         warnings.warn(
diff --git a/python/llm/src/bigdl/llm/transformers/embedding.py b/python/llm/src/bigdl/llm/transformers/embedding.py
index ba71b5c0..63cf3b81 100644
--- a/python/llm/src/bigdl/llm/transformers/embedding.py
+++ b/python/llm/src/bigdl/llm/transformers/embedding.py
@@ -20,6 +20,8 @@ from torch import Tensor
 from torch.nn import functional as F
 from torch.nn import Parameter
 from typing import Optional
+from bigdl.llm.transformers.low_bit_linear import FP4Params
+from bigdl.llm.utils.common import invalidInputError


 # To prevent insufficient available memory when moving embedding from XPU back to CPU,
@@ -72,3 +74,39 @@ class LLMEmbedding(torch.nn.Embedding):

     def forward(self, x: Tensor):
         return super().forward(x.to('cpu')).to(x.device)
+
+
+class LowBitEmbedding(torch.nn.Embedding):
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: Optional[int] = None,
+                 max_norm: Optional[float] = None,
+                 norm_type: float = 2.,
+                 scale_grad_by_freq: bool = False,
+                 sparse: bool = False,
+                 _weight: Optional[Tensor] = None,
+                 _freeze: bool = False,
+                 device=None, dtype=None,
+                 qtype=None) -> None:
+        super().__init__(num_embeddings, embedding_dim, padding_idx,
+                         max_norm, norm_type, scale_grad_by_freq, sparse,
+                         _weight, device, dtype)
+        self.weight = FP4Params(self.weight.data,
+                                requires_grad=False,
+                                quantized=False, _shape=None, qtype=qtype)
+        self.embedding_dim = embedding_dim
+
+    def forward(self, x: Tensor):
+        invalidInputError(x.device.type == "xpu",
+                          "`LowBitEmbedding` only supports GPU now.")
+        try:
+            import intel_extension_for_pytorch
+            import linear_q4_0
+        except ModuleNotFoundError:
+            invalidInputError(False,
+                              "Please `pip install bigdl_core_xe` first.")
+
+        result = linear_q4_0.dequantize_rows(x.contiguous(), self.weight.data,
+                                             self.weight.qtype, self.embedding_dim)
+        return result
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 9ac192ce..ec75aee7 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -129,6 +129,8 @@ class _BaseAutoModelClass:
             added to llama.cpp.
         :param model_hub: str value, options are ``'huggingface'`` and ``'modelscope'``,
             specify the model hub. Default to be ``'huggingface'``.
+        :param embedding_qtype: str value, options are ``'q2_k'`` for now. Default to be ``None``.
+            When set, the relevant low-bit optimization will be applied to the ``nn.Embedding`` layer.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
@@ -159,6 +161,7 @@ class _BaseAutoModelClass:
         user_quantization_config = kwargs.pop("quantization_config", None)
         speculative = kwargs.pop("speculative", False)
         torch_dtype = kwargs.pop("torch_dtype", None)
+        embedding_qtype = kwargs.pop("embedding_qtype", None)

         if user_quantization_config is not None and \
                 "BitsAndBytesConfig" in str(user_quantization_config.__class__):
@@ -278,9 +281,15 @@ class _BaseAutoModelClass:
             if q_k in ["iq2_xxs", "iq2_xs"]:
                 invalidInputError(imatrix_file is not None,
                                   "For iq2_xxs and iq2_xs quantization, imatrix is needed.")
+            cpu_embedding = kwargs.get("cpu_embedding", False)
+            # for 2-bit qtypes, quantize the embedding by default
+            if q_k in ["iq2_xxs", "iq2_xs", "q2_k"] and not cpu_embedding and \
+                    embedding_qtype is None:
+                embedding_qtype = "q2_k"
             if imatrix_file is not None:
                 imatrix_data = load_imatrix_data(imatrix_file)
-                kwargs['imatrix_data'] = imatrix_data
+                kwargs["imatrix_data"] = imatrix_data
+            kwargs["embedding_qtype"] = embedding_qtype

             model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
             if speculative:
@@ -339,6 +348,9 @@ class _BaseAutoModelClass:
         lightweight_bmm = kwargs.pop("lightweight_bmm", False)
         quant_config = kwargs.pop("quantization_config", None)
         imatrix_data = kwargs.pop("imatrix_data", None)
+        embedding_qtype = kwargs.pop("embedding_qtype", None)
+        if embedding_qtype is not None:
+            embedding_qtype = ggml_tensor_qtype[embedding_qtype]
         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)
         awq_config = None
@@ -400,7 +412,8 @@ class _BaseAutoModelClass:
                                      modules_to_not_convert=modules_to_not_convert,
                                      cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
                                      torch_dtype=kwargs.get("torch_dtype", 'auto'),
-                                     imatrix_data=imatrix_data)
+                                     imatrix_data=imatrix_data,
+                                     embedding_qtype=embedding_qtype)
         model.config.update({"bigdl_transformers_low_bit": q_k})

         # enable tie_word_embeddings for MPT
@@ -469,6 +482,7 @@ class _BaseAutoModelClass:
         offload_folder = kwargs.pop("offload_folder", None)
         offload_state_dict = kwargs.pop("offload_state_dict", False)
         torch_dtype = kwargs.pop("torch_dtype", "auto")
+        embedding_qtype = kwargs.pop("embedding_qtype", None)
         sharded_metadata = None

         config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
@@ -488,6 +502,10 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)

         qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
+        if bigdl_transformers_low_bit in ["iq2_xxs", "iq2_xs", "q2_k"] and not cpu_embedding:
+            embedding_qtype = "q2_k"
+        if embedding_qtype is not None:
+            embedding_qtype = ggml_tensor_qtype[embedding_qtype]

         has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
         has_local_code = type(config) in cls.HF_Model._model_mapping.keys()
@@ -572,7 +590,8 @@ class _BaseAutoModelClass:
         quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                      modules_to_not_convert=modules_to_not_convert,
-                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm)
+                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
+                                     embedding_qtype=embedding_qtype)

         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
diff --git a/python/llm/src/bigdl/llm/transformers/utils.py b/python/llm/src/bigdl/llm/transformers/utils.py
index e3751fe7..7f85f03b 100644
--- a/python/llm/src/bigdl/llm/transformers/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/utils.py
@@ -224,37 +224,49 @@ def load_imatrix_data(imatrix_file):
     return imatrix_data


+def module_name_process(full_module_name):
+    # full module name may look like model.layers.31.self_attn.o_proj
+    # TODO: align and generalize this better
+    module_name = full_module_name.split('.')
+    if len(module_name) == 5:
+        layer = module_name[2]
+        cur_module = module_name[-1][:-5]
+        new_module_name = '_'.join([layer, cur_module])
+    elif len(module_name) == 1:
+        new_module_name = module_name[0]
+        layer = None
+        cur_module = None
+    return new_module_name, layer, cur_module
+
+
 def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
-    if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"],
-                 ggml_tensor_qtype["q2_k"]] and imatrix_data is not None:
+    cur_qtype = qtype
+    if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
         # For quantization which needs importance matrix
-        # module name preprocess
-        # full name maybe model.layers.31.self_attn.o_proj
-        # TODO: just consider llama/mistral here
-        # TODO: how to better aligned and generalize
-        module_name = full_module_name.split('.')
-        cur_qtype = qtype
-        if len(module_name) == 5:
-            layer = module_name[2]
-            cur_module = module_name[-1][:-5]
-            new_module_name = '_'.join([layer, cur_module])
-        elif len(module_name) == 1:
-            new_module_name = module_name[0]
-            layer = None
-            cur_module = None
+        new_module_name, layer, cur_module = module_name_process(full_module_name)
+        # custom mixed quantization strategy
+        if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
+            cur_qtype = ggml_tensor_qtype['q2_k']
         if imatrix_data is not None and new_module_name in imatrix_data:
             cur_imatrix = imatrix_data[new_module_name]
-            # custom mixed quantization strategy
-            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
-                    or new_module_name == 'lm_head':
-                cur_qtype = ggml_tensor_qtype['sym_int4']
         else:
+            # if no imatrix is available, use sym_int8 for lm_head
             cur_imatrix = None
-            # custom mixed quantization strategy
-            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
-                    or new_module_name == 'lm_head':
-                cur_qtype = ggml_tensor_qtype['sym_int4']
+            if new_module_name == 'lm_head':
+                cur_qtype = ggml_tensor_qtype['sym_int8']
         return cur_qtype, cur_imatrix
+    elif qtype == ggml_tensor_qtype["q2_k"]:
+        new_module_name, layer, cur_module = module_name_process(full_module_name)
+        if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
+            # TODO: q2_k may need other k-quants types here
+            cur_qtype = ggml_tensor_qtype['q2_k']
+        if imatrix_data is not None and new_module_name in imatrix_data:
+            cur_imatrix = imatrix_data[new_module_name]
+        else:
+            # if no imatrix is available, use sym_int8 for lm_head
+            cur_imatrix = None
+            if new_module_name == 'lm_head':
+                cur_qtype = ggml_tensor_qtype['sym_int8']
     else:
         return qtype, None
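
Usage sketch (not part of the diff): the snippet below shows how the new `embedding_qtype` option could be exercised through `from_pretrained`, assuming bigdl-llm's existing `load_in_low_bit` keyword; the model id is a placeholder. Per the model.py change above, 2-bit qtypes ("iq2_xxs", "iq2_xs", "q2_k") already default to `embedding_qtype="q2_k"` unless `cpu_embedding=True`, so passing it explicitly is only for illustration.

    from bigdl.llm.transformers import AutoModelForCausalLM

    # Placeholder model id; any Hugging Face causal LM repo works the same way.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        load_in_low_bit="q2_k",      # 2-bit weight quantization
        embedding_qtype="q2_k",      # also quantize nn.Embedding via LowBitEmbedding
        optimize_model=True,
        trust_remote_code=True,
    )
    # LowBitEmbedding.forward expects XPU tensors, so move the model to XPU before use.
    model = model.to("xpu")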