LLM : Support embedding quantization (only q2k now) (#10170)
* basic logic added * basic support * support save&load, update mixed strategy * fix style * use int8 for lm_head * add check for xpu
This commit is contained in:
		
							parent
							
								
									eca69a6022
								
							
						
					
					
						commit
						3288acb8de
					
				
					 4 changed files with 128 additions and 31 deletions
				
			
		| 
						 | 
				
			
			@ -191,10 +191,10 @@ def convert_gptq(module, awq=False, llm_awq=False):
 | 
			
		|||
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		||||
                                 current_key_name=None, convert_shape_only=False,
 | 
			
		||||
                                 cpu_embedding=False, prefix_name='',
 | 
			
		||||
                                 imatrix_data=None):
 | 
			
		||||
                                 imatrix_data=None, embedding_qtype=None):
 | 
			
		||||
    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
 | 
			
		||||
        FP16Linear, BF16Linear
 | 
			
		||||
    from bigdl.llm.transformers.embedding import LLMEmbedding
 | 
			
		||||
    from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
 | 
			
		||||
    has_been_replaced = False
 | 
			
		||||
 | 
			
		||||
    for name, module in model.named_children():
 | 
			
		||||
| 
						 | 
				
			
			@ -323,6 +323,32 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                sparse=module.sparse,
 | 
			
		||||
                _weight=module.weight.data,
 | 
			
		||||
            )
 | 
			
		||||
        elif type(module) == nn.Embedding and embedding_qtype is not None:
 | 
			
		||||
            q_embedding = LowBitEmbedding(
 | 
			
		||||
                num_embeddings=module.num_embeddings,
 | 
			
		||||
                embedding_dim=module.embedding_dim,
 | 
			
		||||
                padding_idx=module.padding_idx,
 | 
			
		||||
                max_norm=module.max_norm,
 | 
			
		||||
                norm_type=module.norm_type,
 | 
			
		||||
                scale_grad_by_freq=module.scale_grad_by_freq,
 | 
			
		||||
                sparse=module.sparse,
 | 
			
		||||
                _weight=module.weight.data,
 | 
			
		||||
                qtype=embedding_qtype,
 | 
			
		||||
            )
 | 
			
		||||
            device = module.weight.data.device
 | 
			
		||||
            # Copy the weights
 | 
			
		||||
            paramsLowBit = FP4Params(data=module.weight.data,
 | 
			
		||||
                                     requires_grad=False,
 | 
			
		||||
                                     quantized=False,
 | 
			
		||||
                                     _shape=None,
 | 
			
		||||
                                     convert_shape_only=convert_shape_only,
 | 
			
		||||
                                     qtype=embedding_qtype,
 | 
			
		||||
                                     in_features=module.embedding_dim).to(device)
 | 
			
		||||
            q_embedding._parameters['weight'] = paramsLowBit
 | 
			
		||||
            model._modules[name] = q_embedding
 | 
			
		||||
            # Force requires grad to False to avoid unexpected errors
 | 
			
		||||
            model._modules[name].requires_grad_(False)
 | 
			
		||||
            module.weight = None
 | 
			
		||||
 | 
			
		||||
        # Remove the last key for recursion
 | 
			
		||||
        if len(list(module.children())) > 0:
 | 
			
		||||
| 
						 | 
				
			
			@ -334,7 +360,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                convert_shape_only,
 | 
			
		||||
                cpu_embedding,
 | 
			
		||||
                prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
 | 
			
		||||
                imatrix_data=imatrix_data
 | 
			
		||||
                imatrix_data=imatrix_data,
 | 
			
		||||
                embedding_qtype=embedding_qtype
 | 
			
		||||
            )
 | 
			
		||||
            has_been_replaced = _flag or has_been_replaced
 | 
			
		||||
    return model, has_been_replaced
 | 
			
		||||
| 
						 | 
				
			
			@ -512,7 +539,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 | 
			
		|||
                         convert_shape_only=False, device="cpu",
 | 
			
		||||
                         modules_to_not_convert=None, cpu_embedding=False,
 | 
			
		||||
                         lightweight_bmm=False, torch_dtype="auto",
 | 
			
		||||
                         imatrix_data=None):
 | 
			
		||||
                         imatrix_data=None, embedding_qtype=None):
 | 
			
		||||
    logger.info(f"Converting the current model to "
 | 
			
		||||
                f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
 | 
			
		||||
                f"format......")
 | 
			
		||||
| 
						 | 
				
			
			@ -535,6 +562,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 | 
			
		|||
        model, qtype, modules_to_not_convert,
 | 
			
		||||
        None, convert_shape_only, cpu_embedding,
 | 
			
		||||
        imatrix_data=imatrix_data,
 | 
			
		||||
        embedding_qtype=embedding_qtype
 | 
			
		||||
    )
 | 
			
		||||
    if not has_been_replaced:
 | 
			
		||||
        warnings.warn(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,6 +20,8 @@ from torch import Tensor
 | 
			
		|||
from torch.nn import functional as F
 | 
			
		||||
from torch.nn import Parameter
 | 
			
		||||
from typing import Optional
 | 
			
		||||
from bigdl.llm.transformers.low_bit_linear import FP4Params
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# To prevent insufficient available memory when moving embedding from XPU back to CPU,
 | 
			
		||||
| 
						 | 
				
			
			@ -72,3 +74,39 @@ class LLMEmbedding(torch.nn.Embedding):
 | 
			
		|||
 | 
			
		||||
    def forward(self, x: Tensor):
 | 
			
		||||
        return super().forward(x.to('cpu')).to(x.device)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LowBitEmbedding(torch.nn.Embedding):
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 num_embeddings: int,
 | 
			
		||||
                 embedding_dim: int,
 | 
			
		||||
                 padding_idx: Optional[int] = None,
 | 
			
		||||
                 max_norm: Optional[float] = None,
 | 
			
		||||
                 norm_type: float = 2.,
 | 
			
		||||
                 scale_grad_by_freq: bool = False,
 | 
			
		||||
                 sparse: bool = False,
 | 
			
		||||
                 _weight: Optional[Tensor] = None,
 | 
			
		||||
                 _freeze: bool = False,
 | 
			
		||||
                 device=None, dtype=None,
 | 
			
		||||
                 qtype=None) -> None:
 | 
			
		||||
        super().__init__(num_embeddings, embedding_dim, padding_idx,
 | 
			
		||||
                         max_norm, norm_type, scale_grad_by_freq, sparse,
 | 
			
		||||
                         _weight, device, dtype)
 | 
			
		||||
        self.weight = FP4Params(self.weight.data,
 | 
			
		||||
                                requires_grad=False,
 | 
			
		||||
                                quantized=False, _shape=None, qtype=qtype)
 | 
			
		||||
        self.embedding_dim = embedding_dim
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: Tensor):
 | 
			
		||||
        invalidInputError(x.device.type == "xpu",
 | 
			
		||||
                          "`LowBitEmbedding` only supports GPU now.")
 | 
			
		||||
        try:
 | 
			
		||||
            import intel_extension_for_pytorch
 | 
			
		||||
            import linear_q4_0
 | 
			
		||||
        except ModuleNotFoundError:
 | 
			
		||||
            invalidInputError(False,
 | 
			
		||||
                              "Please `pip install bigdl_core_xe` first.")
 | 
			
		||||
 | 
			
		||||
        result = linear_q4_0.dequantize_rows(x.contiguous(), self.weight.data,
 | 
			
		||||
                                             self.weight.qtype, self.embedding_dim)
 | 
			
		||||
        return result
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -129,6 +129,8 @@ class _BaseAutoModelClass:
 | 
			
		|||
            added to llama.cpp.
 | 
			
		||||
        :param model_hub: str value, options are ``'huggingface'`` and ``'modelscope'``,
 | 
			
		||||
            specify the model hub. Default to be ``'huggingface'``.
 | 
			
		||||
        :param embedding_qtype: str value, options are ``'q2_k'`` now. Default to be None.
 | 
			
		||||
            Relevant low bit optimizations will be applied to nn.Embedding layer.
 | 
			
		||||
        :return: a model instance
 | 
			
		||||
        """
 | 
			
		||||
        pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
 | 
			
		||||
| 
						 | 
				
			
			@ -159,6 +161,7 @@ class _BaseAutoModelClass:
 | 
			
		|||
        user_quantization_config = kwargs.pop("quantization_config", None)
 | 
			
		||||
        speculative = kwargs.pop("speculative", False)
 | 
			
		||||
        torch_dtype = kwargs.pop("torch_dtype", None)
 | 
			
		||||
        embedding_qtype = kwargs.pop("embedding_qtype", None)
 | 
			
		||||
 | 
			
		||||
        if user_quantization_config is not None and \
 | 
			
		||||
                "BitsAndBytesConfig" in str(user_quantization_config.__class__):
 | 
			
		||||
| 
						 | 
				
			
			@ -278,9 +281,15 @@ class _BaseAutoModelClass:
 | 
			
		|||
            if q_k in ["iq2_xxs", "iq2_xs"]:
 | 
			
		||||
                invalidInputError(imatrix_file is not None,
 | 
			
		||||
                                  "For iq2_xxs and iq2_xs quantization, imatrix is needed.")
 | 
			
		||||
            cpu_embedding = kwargs.get("cpu_embedding", False)
 | 
			
		||||
            # for 2bit, default use embedding_quantization
 | 
			
		||||
            if q_k in ["iq2_xxs", "iq2_xs", "q2_k"] and not cpu_embedding and \
 | 
			
		||||
                    embedding_qtype is None:
 | 
			
		||||
                embedding_qtype = "q2_k"
 | 
			
		||||
            if imatrix_file is not None:
 | 
			
		||||
                imatrix_data = load_imatrix_data(imatrix_file)
 | 
			
		||||
                kwargs['imatrix_data'] = imatrix_data
 | 
			
		||||
                kwargs["imatrix_data"] = imatrix_data
 | 
			
		||||
            kwargs["embedding_qtype"] = embedding_qtype
 | 
			
		||||
            model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
            if speculative:
 | 
			
		||||
| 
						 | 
				
			
			@ -339,6 +348,9 @@ class _BaseAutoModelClass:
 | 
			
		|||
        lightweight_bmm = kwargs.pop("lightweight_bmm", False)
 | 
			
		||||
        quant_config = kwargs.pop("quantization_config", None)
 | 
			
		||||
        imatrix_data = kwargs.pop("imatrix_data", None)
 | 
			
		||||
        embedding_qtype = kwargs.pop("embedding_qtype", None)
 | 
			
		||||
        if embedding_qtype is not None:
 | 
			
		||||
            embedding_qtype = ggml_tensor_qtype[embedding_qtype]
 | 
			
		||||
        _args = copy.deepcopy(args)
 | 
			
		||||
        _kwargs = copy.deepcopy(kwargs)
 | 
			
		||||
        awq_config = None
 | 
			
		||||
| 
						 | 
				
			
			@ -400,7 +412,8 @@ class _BaseAutoModelClass:
 | 
			
		|||
                                     modules_to_not_convert=modules_to_not_convert,
 | 
			
		||||
                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
 | 
			
		||||
                                     torch_dtype=kwargs.get("torch_dtype", 'auto'),
 | 
			
		||||
                                     imatrix_data=imatrix_data)
 | 
			
		||||
                                     imatrix_data=imatrix_data,
 | 
			
		||||
                                     embedding_qtype=embedding_qtype)
 | 
			
		||||
        model.config.update({"bigdl_transformers_low_bit": q_k})
 | 
			
		||||
 | 
			
		||||
        # enable tie_word_embeddings for MPT
 | 
			
		||||
| 
						 | 
				
			
			@ -469,6 +482,7 @@ class _BaseAutoModelClass:
 | 
			
		|||
        offload_folder = kwargs.pop("offload_folder", None)
 | 
			
		||||
        offload_state_dict = kwargs.pop("offload_state_dict", False)
 | 
			
		||||
        torch_dtype = kwargs.pop("torch_dtype", "auto")
 | 
			
		||||
        embedding_qtype = kwargs.pop("embedding_qtype", None)
 | 
			
		||||
        sharded_metadata = None
 | 
			
		||||
 | 
			
		||||
        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
 | 
			
		||||
| 
						 | 
				
			
			@ -488,6 +502,10 @@ class _BaseAutoModelClass:
 | 
			
		|||
        optimize_model = kwargs.pop("optimize_model", True)
 | 
			
		||||
 | 
			
		||||
        qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
 | 
			
		||||
        if bigdl_transformers_low_bit in ["iq2_xxs", "iq2_xs", "q2_k"] and not cpu_embedding:
 | 
			
		||||
            embedding_qtype = "q2_k"
 | 
			
		||||
        if embedding_qtype is not None:
 | 
			
		||||
            embedding_qtype = ggml_tensor_qtype[embedding_qtype]
 | 
			
		||||
 | 
			
		||||
        has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
 | 
			
		||||
        has_local_code = type(config) in cls.HF_Model._model_mapping.keys()
 | 
			
		||||
| 
						 | 
				
			
			@ -572,7 +590,8 @@ class _BaseAutoModelClass:
 | 
			
		|||
        quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
 | 
			
		||||
        model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
 | 
			
		||||
                                     modules_to_not_convert=modules_to_not_convert,
 | 
			
		||||
                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm)
 | 
			
		||||
                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
 | 
			
		||||
                                     embedding_qtype=embedding_qtype)
 | 
			
		||||
 | 
			
		||||
        if is_sharded:
 | 
			
		||||
            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -224,37 +224,49 @@ def load_imatrix_data(imatrix_file):
 | 
			
		|||
    return imatrix_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def module_name_process(full_module_name):
 | 
			
		||||
    # full name maybe model.layers.31.self_attn.o_proj
 | 
			
		||||
    # TODO: how to better aligned and generalize
 | 
			
		||||
    module_name = full_module_name.split('.')
 | 
			
		||||
    if len(module_name) == 5:
 | 
			
		||||
        layer = module_name[2]
 | 
			
		||||
        cur_module = module_name[-1][:-5]
 | 
			
		||||
        new_module_name = '_'.join([layer, cur_module])
 | 
			
		||||
    elif len(module_name) == 1:
 | 
			
		||||
        new_module_name = module_name[0]
 | 
			
		||||
        layer = None
 | 
			
		||||
        cur_module = None
 | 
			
		||||
    return new_module_name, layer, cur_module
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
 | 
			
		||||
    if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"],
 | 
			
		||||
                 ggml_tensor_qtype["q2_k"]] and imatrix_data is not None:
 | 
			
		||||
    cur_qtype = qtype
 | 
			
		||||
    if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
 | 
			
		||||
        # For quantization which needs importance matrix
 | 
			
		||||
        # module name preprocess
 | 
			
		||||
        # full name maybe model.layers.31.self_attn.o_proj
 | 
			
		||||
        # TODO: just consider llama/mistral here
 | 
			
		||||
        # TODO: how to better aligned and generalize
 | 
			
		||||
        module_name = full_module_name.split('.')
 | 
			
		||||
        cur_qtype = qtype
 | 
			
		||||
        if len(module_name) == 5:
 | 
			
		||||
            layer = module_name[2]
 | 
			
		||||
            cur_module = module_name[-1][:-5]
 | 
			
		||||
            new_module_name = '_'.join([layer, cur_module])
 | 
			
		||||
        elif len(module_name) == 1:
 | 
			
		||||
            new_module_name = module_name[0]
 | 
			
		||||
            layer = None
 | 
			
		||||
            cur_module = None
 | 
			
		||||
        new_module_name, layer, cur_module = module_name_process(full_module_name)
 | 
			
		||||
        # custom mixed quantization strategy
 | 
			
		||||
        if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
 | 
			
		||||
            cur_qtype = ggml_tensor_qtype['q2_k']
 | 
			
		||||
        if imatrix_data is not None and new_module_name in imatrix_data:
 | 
			
		||||
            cur_imatrix = imatrix_data[new_module_name]
 | 
			
		||||
            # custom mixed quantization strategy
 | 
			
		||||
            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
 | 
			
		||||
                    or new_module_name == 'lm_head':
 | 
			
		||||
                cur_qtype = ggml_tensor_qtype['sym_int4']
 | 
			
		||||
        else:
 | 
			
		||||
            # if no imatrix is available, use fp8 for lm_head
 | 
			
		||||
            cur_imatrix = None
 | 
			
		||||
            # custom mixed quantization strategy
 | 
			
		||||
            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
 | 
			
		||||
                    or new_module_name == 'lm_head':
 | 
			
		||||
                cur_qtype = ggml_tensor_qtype['sym_int4']
 | 
			
		||||
            if new_module_name == 'lm_head':
 | 
			
		||||
                cur_qtype = ggml_tensor_qtype['sym_int8']
 | 
			
		||||
        return cur_qtype, cur_imatrix
 | 
			
		||||
    elif qtype == ggml_tensor_qtype["q2_k"]:
 | 
			
		||||
        new_module_name, layer, cur_module = module_name_process(full_module_name)
 | 
			
		||||
        if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
 | 
			
		||||
            # TODO: q2_k need others k-quants type here
 | 
			
		||||
            cur_qtype = ggml_tensor_qtype['q2_k']
 | 
			
		||||
        if imatrix_data is not None and new_module_name in imatrix_data:
 | 
			
		||||
            cur_imatrix = imatrix_data[new_module_name]
 | 
			
		||||
        else:
 | 
			
		||||
            # if no imatrix is available, use fp8 for lm_head
 | 
			
		||||
            cur_imatrix = None
 | 
			
		||||
            if new_module_name == 'lm_head':
 | 
			
		||||
                cur_qtype = ggml_tensor_qtype['sym_int8']
 | 
			
		||||
    else:
 | 
			
		||||
        return qtype, None
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue