LLM: add mixed precision for lm_head (#10795)
* add mixed_quantization
* meet code review
* update
* fix style
* meet review
parent 8796401b08
commit 439c834ed3
2 changed files with 25 additions and 4 deletions
@@ -92,6 +92,13 @@ if is_auto_awq_available():
     from transformers.utils.quantization_config import AwqBackendPackingMethod
 
 
+def is_lm_head(name, model_config, out_features):
+    if name == "lm_head" or getattr(model_config, "vocab_size", None) == out_features:
+        return True
+    else:
+        return False
+
+
 def is_linear_module(module):
 
     in_features = None
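A quick sketch of how the new helper behaves (not part of the commit; the import path and config values are assumptions). The head is recognized either by its attribute name or by an out_features that equals the model's vocabulary size, so output projections that are not literally named lm_head are still caught:

from transformers import PretrainedConfig
from ipex_llm.transformers.convert import is_lm_head  # module path assumed

cfg = PretrainedConfig(vocab_size=32000)  # hypothetical config
assert is_lm_head("lm_head", cfg, out_features=32000)    # matched by name
assert is_lm_head("output", cfg, out_features=32000)     # matched by vocab_size
assert not is_lm_head("q_proj", cfg, out_features=4096)  # ordinary linear layer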
@@ -220,7 +227,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  cpu_embedding=False, prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_config=None, torch_dtype=torch.float32,
-                                 enable_xetla=False):
+                                 enable_xetla=False,
+                                 mixed_precision=False):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@@ -237,7 +245,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         if is_linear and not isinstance(module, LowBitLinear):
             in_features, out_features, mp_group = linear_args
             optimize_lm_head = False
-            if name == "lm_head":
+            if is_lm_head(name, model_config, out_features):
                 model_type = getattr(model_config, "model_type", None)
                 if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
                                                                       None) == "1":
@@ -291,6 +299,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                                                        full_module_name,
                                                                        imatrix_data,
                                                                        model_config)
+                    # mixed precision for lm_head
+                    if mixed_precision and is_lm_head(name, model_config, out_features):
+                        if cur_qtype in [ggml_tensor_qtype["sym_int4"],
+                                         ggml_tensor_qtype["asym_int4"]]:
+                            cur_qtype = ggml_tensor_qtype["sym_int8"]
                     device = module.weight.data.device
                     # Copy the weights
                     paramsLowBit = FP4Params(data=module.weight.data,
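Condensed into a standalone sketch, the rule added above is: when mixed_precision is enabled and the layer is the lm_head, a 4-bit qtype is promoted to sym_int8, and every other layer keeps the qtype chosen by the regular conversion path. The helper name below is hypothetical and the import path is assumed:

from ipex_llm.ggml.quantize import ggml_tensor_qtype  # import path assumed

def pick_lm_head_qtype(cur_qtype, mixed_precision, layer_is_lm_head):
    # Hypothetical condensation of the hunk above.
    if mixed_precision and layer_is_lm_head and \
            cur_qtype in [ggml_tensor_qtype["sym_int4"],
                          ggml_tensor_qtype["asym_int4"]]:
        return ggml_tensor_qtype["sym_int8"]
    return cur_qtype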
@@ -409,6 +422,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 model_config=model_config,
                 torch_dtype=torch_dtype,
                 enable_xetla=enable_xetla,
+                mixed_precision=mixed_precision
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -684,7 +698,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          lightweight_bmm=False, torch_dtype="auto",
                          imatrix_data=None,
                          embedding_qtype=None,
-                         enable_xetla=False):
+                         enable_xetla=False,
+                         mixed_precision=False):
     logger.info(f"Converting the current model to "
                 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                 f"format......")
@@ -709,6 +724,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         model_config=getattr(model, "config", None),
         torch_dtype=torch_dtype,
         enable_xetla=enable_xetla,
+        mixed_precision=mixed_precision,
     )
     if not has_been_replaced:
         warnings.warn(
@@ -140,6 +140,9 @@ class _BaseAutoModelClass:
             specify the model hub. Default to be ``'huggingface'``.
         :param embedding_qtype: str value, options are ``'q2_k'`` now. Default to be None.
            Relevant low bit optimizations will be applied to nn.Embedding layer.
+        :param mixed_precision: boolean value, whether to use mixed precision quantization.
+            Default to be False. If set to True, we will use sym_int8 for lm_head when
+            load_in_low_bit is sym_int4 or asym_int4.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
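With the keyword argument documented above, enabling the feature from user code is a single extra argument to from_pretrained. A minimal sketch, assuming an ipex_llm installation and using a placeholder model id:

from ipex_llm.transformers import AutoModelForCausalLM

# Everything is quantized to sym_int4 except lm_head, which is promoted to sym_int8.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
    load_in_low_bit="sym_int4",
    mixed_precision=True,
)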
@@ -394,6 +397,7 @@ class _BaseAutoModelClass:
         quant_config = kwargs.pop("quantization_config", None)
         imatrix_data = kwargs.pop("imatrix_data", None)
         embedding_qtype = kwargs.pop("embedding_qtype", None)
+        mixed_precision = kwargs.pop("mixed_precision", False)
         if embedding_qtype is not None:
             embedding_qtype = ggml_tensor_qtype[embedding_qtype]
         enable_xetla = kwargs.pop("enable_xetla", False)
@@ -463,7 +467,8 @@ class _BaseAutoModelClass:
                                      torch_dtype=kwargs.get("torch_dtype", 'auto'),
                                      imatrix_data=imatrix_data,
                                      embedding_qtype=embedding_qtype,
-                                     enable_xetla=enable_xetla,)
+                                     enable_xetla=enable_xetla,
+                                     mixed_precision=mixed_precision)
         model.config.update({"bigdl_transformers_low_bit": q_k})
 
         # enable tie_word_embeddings for MPT
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue