diff --git a/python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py b/python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py
index bea2fef5..abba2767 100644
--- a/python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py
+++ b/python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py
@@ -950,8 +950,9 @@ def ggml_quantize_tensor(
     n: ctypes.c_size_t,
     k: ctypes.c_int,
     hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
+    scale_search: ctypes.c_bool,
 ) -> int:
-    return _lib.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
+    return _lib.ggml_quantize_tensor(src, dst, qtype, n, k, hist, scale_search)
 
 
 _lib.ggml_quantize_tensor.argtypes = [
@@ -961,6 +962,7 @@ _lib.ggml_quantize_tensor.argtypes = [
     ctypes.c_size_t,
     ctypes.c_int,
     ctypes.POINTER(ctypes.c_int64),
+    ctypes.c_bool,
 ]
 _lib.ggml_quantize_tensor.restype = ctypes.c_size_t
 
diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 746e2cac..86a4c258 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -287,6 +287,12 @@ def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
     return ggml_weight, g_id_map
 
 
+def use_scale_search(model_config, qtype):
+    if qtype == ggml_tensor_qtype["fp6"] and model_config.model_type not in ["qwen2"]:
+        return True
+    return False
+
+
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
@@ -295,6 +301,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  enable_xetla=False,
                                  mixed_precision=False,
                                  act_order=False,
+                                 enable_scale_search=False,
                                  ):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
@@ -333,6 +340,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     enable_xetla=enable_xetla,
                     optimize_lm_head=optimize_lm_head,
                     act_order=act_order,
+                    enable_scale_search=enable_scale_search,
                 )
                 device = module.qweight.data.device
                 invalidInputError(device.type != "meta",
@@ -350,7 +358,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                          _shape=(out_features, in_features),
                                          convert_shape_only=convert_shape_only,
                                          qtype=qtype,
-                                         enable_xetla=enable_xetla).to(device)
+                                         enable_xetla=enable_xetla,
+                                         enable_scale_search=enable_scale_search).to(device)
                 new_linear._parameters['weight'] = paramsLowBit
                 if has_bias:
                     new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -376,7 +385,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     module.bias is not None,
                     mp_group=mp_group,
                     enable_xetla=enable_xetla,
-                    optimize_lm_head=optimize_lm_head
+                    optimize_lm_head=optimize_lm_head,
+                    enable_scale_search=enable_scale_search,
                 )
                 device = module.weight.data.device
                 # Copy the weights
@@ -388,7 +398,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                          qtype=cur_qtype,
                                          imatrix=cur_imatrix,
                                          in_features=in_features,
-                                         enable_xetla=enable_xetla).to(device)
+                                         enable_xetla=enable_xetla,
+                                         enable_scale_search=enable_scale_search).to(device)
                 new_linear._parameters['weight'] = paramsLowBit
                 if module.bias is not None:
                     new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -498,6 +509,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 enable_xetla=enable_xetla,
                 mixed_precision=mixed_precision,
                 act_order=act_order,
+                enable_scale_search=enable_scale_search,
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -769,17 +781,22 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     if getattr(model, "quantization_method", None) == "gptq":
         act_order = model.config.quantization_config.desc_act
 
+    model_config = getattr(model, "config", None)
+
+    enable_scale_search = use_scale_search(model_config, qtype)
+
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
         convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
-        model_config=getattr(model, "config", None),
+        model_config=model_config,
         torch_dtype=torch_dtype,
         enable_xetla=enable_xetla,
         mixed_precision=mixed_precision,
         act_order=act_order,
+        enable_scale_search=enable_scale_search,
     )
     if not has_been_replaced:
         warnings.warn(
diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index 37a9780f..1038688e 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -201,7 +201,8 @@ def get_qk_size(qtype: int):
 def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                        device=None, convert_shape_only=False,
                        imatrix: torch.Tensor=None,
-                       in_features: int=None):
+                       in_features: int=None,
+                       enable_scale_search: bool=False):
     QK = ggml.ggml_qk_size(qtype)
     block_size_in_bytes = ggml.ggml_type_size(qtype)
 
@@ -222,7 +223,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
     dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
     hist = (ctypes.c_int64 * 16)()
     if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
-        ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
+        ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
     else:
         if imatrix is not None:
             # quantize with importance matrix
@@ -357,7 +358,8 @@ class FP4Params(torch.nn.Parameter):
                 qtype=None,
                 imatrix=None,
                 in_features=None,
-                enable_xetla=False,):
+                enable_xetla=False,
+                enable_scale_search=False):
         if data is None:
             data = torch.empty(0)
 
@@ -370,13 +372,15 @@ class FP4Params(torch.nn.Parameter):
         self.imatrix = imatrix
         self.in_features = in_features
         self.enable_xetla = enable_xetla
+        self.enable_scale_search = enable_scale_search
        return self
 
     def ggml_mse(self, w, ggml_qtype, device):
         from torch.nn.functional import mse_loss
         w_quant = ggml_convert_qtype(w, ggml_qtype, device=device,
-                                     convert_shape_only=self.convert_shape_only)
+                                     convert_shape_only=self.convert_shape_only,
+                                     enable_scale_search=self.enable_scale_search)
         w_dequant = ggml_convert_fp32(w_quant, w.shape,
                                       reduce(mul, w.shape, 1), ggml_qtype)
         mse = mse_loss(w_dequant, w)
@@ -389,7 +393,8 @@ class FP4Params(torch.nn.Parameter):
             if device == 'meta':
                 w_quantized = ggml_convert_qtype(w, SYM_INT4,
                                                  device=device,
-                                                 convert_shape_only=self.convert_shape_only)
+                                                 convert_shape_only=self.convert_shape_only,
+                                                 enable_scale_search=self.enable_scale_search)
                 # TODO: should load from config, the current implementation doesn't support
                 # save/load
                 self.qtype = SYM_INT4
@@ -406,7 +411,8 @@ class FP4Params(torch.nn.Parameter):
             if device == 'meta':
                 w_quantized = ggml_convert_qtype(w, SYM_INT8,
                                                  device=device,
-                                                 convert_shape_only=self.convert_shape_only)
+                                                 convert_shape_only=self.convert_shape_only,
+                                                 enable_scale_search=self.enable_scale_search)
                 # TODO: should load from config, the current implementation doesn't support
                 # save/load
                 self.qtype = SYM_INT8
@@ -424,7 +430,8 @@ class FP4Params(torch.nn.Parameter):
                                                  device=device,
                                                  convert_shape_only=self.convert_shape_only,
                                                  imatrix=self.imatrix,
-                                                 in_features=self.in_features)
+                                                 in_features=self.in_features,
+                                                 enable_scale_search=self.enable_scale_search)
                 self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
@@ -467,7 +474,8 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_xetla=self.enable_xetla)
+                                  enable_xetla=self.enable_xetla,
+                                  enable_scale_search=self.enable_scale_search)
             if self.enable_xetla:
                 device_type = get_xpu_device_type(new_param.data)
                 invalidInputError(device_type == "pvc",
@@ -481,7 +489,8 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_xetla=self.enable_xetla)
+                                  enable_xetla=self.enable_xetla,
+                                  enable_scale_search=self.enable_scale_search)
             if self.enable_xetla:
                 ggml_xpu = ipex_llm_xetla_to_ggml_xpu(new_param.data,
                                                       new_param._shape,
@@ -500,7 +509,8 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_xetla=self.enable_xetla)
+                                  enable_xetla=self.enable_xetla,
+                                  enable_scale_search=self.enable_scale_search)
             return new_param
 
 
@@ -607,12 +617,14 @@ class MatMulLowBitCPU(torch.autograd.Function):
 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
                  conver_to_half=True, mp_group=None, enable_xetla=False,
-                 optimize_lm_head=False, act_order=False):
+                 optimize_lm_head=False, act_order=False,
+                 enable_scale_search=False):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
                                 quantized=False, _shape=None, qtype=qtype,
-                                enable_xetla=enable_xetla)
+                                enable_xetla=enable_xetla,
+                                enable_scale_search=enable_scale_search)
         self.in_len = input_features
         self.out_len = output_features
         self.weight_shape = (self.out_len, self.in_len)
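
A minimal standalone sketch (not part of the patch) of the scale-search decision introduced above, using a mocked model config and illustrative qtype ids; the real mapping comes from ipex_llm.ggml.quantize.ggml_tensor_qtype and the Hugging Face model config:

from types import SimpleNamespace

# Illustrative ids only; the real table lives in ipex_llm.ggml.quantize.
ggml_tensor_qtype = {"sym_int4": 2, "fp6": 26}


def use_scale_search(model_config, qtype):
    # Mirrors the helper added in convert.py: search for better quantization
    # scales only when converting to fp6, and skip it for qwen2 models.
    if qtype == ggml_tensor_qtype["fp6"] and model_config.model_type not in ["qwen2"]:
        return True
    return False


llama_cfg = SimpleNamespace(model_type="llama")
qwen2_cfg = SimpleNamespace(model_type="qwen2")

print(use_scale_search(llama_cfg, ggml_tensor_qtype["fp6"]))       # True
print(use_scale_search(qwen2_cfg, ggml_tensor_qtype["fp6"]))       # False
print(use_scale_search(llama_cfg, ggml_tensor_qtype["sym_int4"]))  # False

The resulting boolean flows from ggml_convert_low_bit through _replace_with_low_bit_linear, FP4Params and LowBitLinear, and finally reaches the native library as the new scale_search argument of ggml_quantize_tensor.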