Add quantization scale search switch (#11326)
* add scale_search switch
* remove llama3 instruct
* remove print
parent 8a3247ac71
commit 0af0102e61

3 changed files with 48 additions and 17 deletions
@@ -950,8 +950,9 @@ def ggml_quantize_tensor(
     n: ctypes.c_size_t,
     k: ctypes.c_int,
     hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
+    scale_search: ctypes.c_bool,
 ) -> int:
-    return _lib.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
+    return _lib.ggml_quantize_tensor(src, dst, qtype, n, k, hist, scale_search)
 
 
 _lib.ggml_quantize_tensor.argtypes = [
@@ -961,6 +962,7 @@ _lib.ggml_quantize_tensor.argtypes = [
     ctypes.c_size_t,
     ctypes.c_int,
     ctypes.POINTER(ctypes.c_int64),
+    ctypes.c_bool,
 ]
 _lib.ggml_quantize_tensor.restype = ctypes.c_size_t
 
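Note on the binding change: once `ctypes.c_bool` is appended to `argtypes`, ctypes coerces a plain Python `bool` at the call site, so the wrapper can simply forward `scale_search`. A minimal sketch of the same binding pattern against a hypothetical shared library (`libdemo.so` and `demo_quantize` are illustrative names, not part of this change):

import ctypes

# Hypothetical library and symbol, for illustration only.
_lib = ctypes.CDLL("libdemo.so")
_lib.demo_quantize.argtypes = [
    ctypes.c_void_p,  # src buffer
    ctypes.c_void_p,  # dst buffer
    ctypes.c_size_t,  # number of elements
    ctypes.c_bool,    # scale-search switch, mirroring the new argument above
]
_lib.demo_quantize.restype = ctypes.c_size_t

def demo_quantize(src, dst, n, scale_search=False):
    # A plain Python bool is converted to c_bool via the argtypes declaration.
    return _lib.demo_quantize(src, dst, n, scale_search)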
				
			
@@ -287,6 +287,12 @@ def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
     return ggml_weight, g_id_map
 
 
+def use_scale_search(model_config, qtype):
+    if qtype == ggml_tensor_qtype["fp6"] and model_config.model_type not in ["qwen2"]:
+        return True
+    return False
+
+
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
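The helper gates the feature narrowly: only FP6 quantization, and not for Qwen2 models. A quick behavioral check (a sketch; it assumes `ggml_tensor_qtype` also contains the usual "sym_int4" entry and uses stand-in config objects):

from types import SimpleNamespace

# Stand-ins for transformers model configs; only .model_type is consulted.
llama_cfg = SimpleNamespace(model_type="llama")
qwen2_cfg = SimpleNamespace(model_type="qwen2")

fp6 = ggml_tensor_qtype["fp6"]
assert use_scale_search(llama_cfg, fp6) is True    # FP6 + non-Qwen2: search enabled
assert use_scale_search(qwen2_cfg, fp6) is False   # Qwen2 is excluded
assert use_scale_search(llama_cfg, ggml_tensor_qtype["sym_int4"]) is False  # other qtypes: off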
				
			
@@ -295,6 +301,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  enable_xetla=False,
                                  mixed_precision=False,
                                  act_order=False,
+                                 enable_scale_search=False,
                                  ):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
@@ -333,6 +340,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         enable_xetla=enable_xetla,
                         optimize_lm_head=optimize_lm_head,
                         act_order=act_order,
+                        enable_scale_search=enable_scale_search,
                     )
                     device = module.qweight.data.device
                     invalidInputError(device.type != "meta",
@@ -350,7 +358,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                              _shape=(out_features, in_features),
                                              convert_shape_only=convert_shape_only,
                                              qtype=qtype,
-                                             enable_xetla=enable_xetla).to(device)
+                                             enable_xetla=enable_xetla,
+                                             enable_scale_search=enable_scale_search).to(device)
                     new_linear._parameters['weight'] = paramsLowBit
                     if has_bias:
                         new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -376,7 +385,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         module.bias is not None,
                         mp_group=mp_group,
                         enable_xetla=enable_xetla,
-                        optimize_lm_head=optimize_lm_head
+                        optimize_lm_head=optimize_lm_head,
+                        enable_scale_search=enable_scale_search,
                     )
                     device = module.weight.data.device
                     # Copy the weights
@@ -388,7 +398,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                              qtype=cur_qtype,
                                              imatrix=cur_imatrix,
                                              in_features=in_features,
-                                             enable_xetla=enable_xetla).to(device)
+                                             enable_xetla=enable_xetla,
+                                             enable_scale_search=enable_scale_search).to(device)
                     new_linear._parameters['weight'] = paramsLowBit
                     if module.bias is not None:
                         new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -498,6 +509,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 enable_xetla=enable_xetla,
                 mixed_precision=mixed_precision,
                 act_order=act_order,
+                enable_scale_search=enable_scale_search,
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -769,17 +781,22 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     if getattr(model, "quantization_method", None) == "gptq":
         act_order = model.config.quantization_config.desc_act
 
+    model_config = getattr(model, "config", None)
+
+    enable_scale_search = use_scale_search(model_config, qtype)
+
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
         convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
-        model_config=getattr(model, "config", None),
+        model_config=model_config,
         torch_dtype=torch_dtype,
         enable_xetla=enable_xetla,
         mixed_precision=mixed_precision,
         act_order=act_order,
+        enable_scale_search=enable_scale_search,
     )
     if not has_been_replaced:
         warnings.warn(
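With this wiring the switch is computed once per model and is not user-facing; it rides along with the existing keyword plumbing. The resulting call path, summarized (names as they appear in this diff; a summary, not runnable end-to-end without a model):

# Flow of the new switch through the conversion stack:
#
# ggml_convert_low_bit(model, qtype, ...)
#     enable_scale_search = use_scale_search(model_config, qtype)
#     -> _replace_with_low_bit_linear(..., enable_scale_search=...)
#         -> LowBitLinear(..., enable_scale_search=...)
#             -> FP4Params(..., enable_scale_search=...)
#                 -> ggml_convert_qtype(..., enable_scale_search=...)
#                     -> ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist,
#                                                  enable_scale_search)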
				
			
@@ -201,7 +201,8 @@ def get_qk_size(qtype: int):
 def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                        device=None, convert_shape_only=False,
                        imatrix: torch.Tensor=None,
-                       in_features: int=None):
+                       in_features: int=None,
+                       enable_scale_search: bool=False):
     QK = ggml.ggml_qk_size(qtype)
     block_size_in_bytes = ggml.ggml_type_size(qtype)
 
@@ -222,7 +223,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
         dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
         hist = (ctypes.c_int64 * 16)()
         if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
-            ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
+            ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
         else:
             if imatrix is not None:
                 # quantize with importance matrix
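For the non-K-quant path the flag now reaches the native quantizer directly. A hedged usage sketch of the updated function (import paths assumed from ipex-llm's layout; the exact scale-search strategy lives in the patched native library):

import torch
from ipex_llm.ggml.quantize import ggml_tensor_qtype          # assumed path
from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype

w = torch.randn(128, 4096, dtype=torch.float32)
# enable_scale_search=True asks the native quantizer to search quantization
# scales rather than use the default scale selection.
q = ggml_convert_qtype(w, ggml_tensor_qtype["sym_int4"],
                       device="cpu", enable_scale_search=True)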
				
			
@@ -357,7 +358,8 @@ class FP4Params(torch.nn.Parameter):
                 qtype=None,
                 imatrix=None,
                 in_features=None,
-                enable_xetla=False,):
+                enable_xetla=False,
+                enable_scale_search=False):
         if data is None:
             data = torch.empty(0)
 
@@ -370,13 +372,15 @@ class FP4Params(torch.nn.Parameter):
         self.imatrix = imatrix
         self.in_features = in_features
         self.enable_xetla = enable_xetla
+        self.enable_scale_search = enable_scale_search
         return self
 
     def ggml_mse(self, w, ggml_qtype, device):
         from torch.nn.functional import mse_loss
         w_quant = ggml_convert_qtype(w, ggml_qtype,
                                      device=device,
-                                     convert_shape_only=self.convert_shape_only)
+                                     convert_shape_only=self.convert_shape_only,
+                                     enable_scale_search=self.enable_scale_search)
         w_dequant = ggml_convert_fp32(w_quant, w.shape,
                                       reduce(mul, w.shape, 1), ggml_qtype)
         mse = mse_loss(w_dequant, w)
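`ggml_mse` already measured quantization error by round-tripping a weight through quantize and dequantize; forwarding the switch keeps that measurement consistent with how the weight will actually be quantized. The round-trip pattern in isolation (a sketch built from the functions in this file):

from functools import reduce
from operator import mul
from torch.nn.functional import mse_loss

def quantization_mse(w, qtype, enable_scale_search=False):
    # Quantize to the target qtype, dequantize back to fp32, and compare:
    # the MSE is a proxy for how much information the qtype discards.
    w_quant = ggml_convert_qtype(w, qtype, device="cpu",
                                 enable_scale_search=enable_scale_search)
    w_dequant = ggml_convert_fp32(w_quant, w.shape,
                                  reduce(mul, w.shape, 1), qtype)
    return mse_loss(w_dequant, w)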
				
			
@@ -389,7 +393,8 @@ class FP4Params(torch.nn.Parameter):
                 if device == 'meta':
                     w_quantized = ggml_convert_qtype(w, SYM_INT4,
                                                      device=device,
-                                                     convert_shape_only=self.convert_shape_only)
+                                                     convert_shape_only=self.convert_shape_only,
+                                                     enable_scale_search=self.enable_scale_search)
                     # TODO: should load from config, the current implementation doesn't support
                     # save/load
                     self.qtype = SYM_INT4
@@ -406,7 +411,8 @@ class FP4Params(torch.nn.Parameter):
                 if device == 'meta':
                     w_quantized = ggml_convert_qtype(w, SYM_INT8,
                                                      device=device,
-                                                     convert_shape_only=self.convert_shape_only)
+                                                     convert_shape_only=self.convert_shape_only,
+                                                     enable_scale_search=self.enable_scale_search)
                     # TODO: should load from config, the current implementation doesn't support
                     # save/load
                     self.qtype = SYM_INT8
@@ -424,7 +430,8 @@ class FP4Params(torch.nn.Parameter):
                                                  device=device,
                                                  convert_shape_only=self.convert_shape_only,
                                                  imatrix=self.imatrix,
-                                                 in_features=self.in_features)
+                                                 in_features=self.in_features,
+                                                 enable_scale_search=self.enable_scale_search)
                 self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
@@ -467,7 +474,8 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_xetla=self.enable_xetla)
+                                  enable_xetla=self.enable_xetla,
+                                  enable_scale_search=self.enable_scale_search)
             if self.enable_xetla:
                 device_type = get_xpu_device_type(new_param.data)
                 invalidInputError(device_type == "pvc",
@@ -481,7 +489,8 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_xetla=self.enable_xetla)
+                                  enable_xetla=self.enable_xetla,
+                                  enable_scale_search=self.enable_scale_search)
             if self.enable_xetla:
                 ggml_xpu = ipex_llm_xetla_to_ggml_xpu(new_param.data,
                                                       new_param._shape,
@@ -500,7 +509,8 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_xetla=self.enable_xetla)
+                                  enable_xetla=self.enable_xetla,
+                                  enable_scale_search=self.enable_scale_search)
             return new_param
 
 
@@ -607,12 +617,14 @@ class MatMulLowBitCPU(torch.autograd.Function):
 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
                  conver_to_half=True, mp_group=None, enable_xetla=False,
-                 optimize_lm_head=False, act_order=False):
+                 optimize_lm_head=False, act_order=False,
+                 enable_scale_search=False):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
                                 quantized=False, _shape=None, qtype=qtype,
-                                enable_xetla=enable_xetla)
+                                enable_xetla=enable_xetla,
+                                enable_scale_search=enable_scale_search)
         self.in_len = input_features
         self.out_len = output_features
         self.weight_shape = (self.out_len, self.in_len)
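End to end, users normally get the switch implicitly via `ggml_convert_low_bit`, but the layer can also be constructed directly. A hedged construction example (import paths assumed from ipex-llm's layout; quantization itself happens when the parameter is later moved to its target device):

import torch
from ipex_llm.ggml.quantize import ggml_tensor_qtype          # assumed path
from ipex_llm.transformers.low_bit_linear import LowBitLinear

layer = LowBitLinear(4096, 11008,
                     qtype=ggml_tensor_qtype["fp6"],
                     bias=False,
                     enable_scale_search=True)  # the new switch
# The FP4Params weight stores enable_scale_search and forwards it to
# ggml_convert_qtype when the weight is quantized (e.g. on .to(...)).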