Add quantization scale search switch (#11326)

* add scale_search switch

* remove llama3 instruct

* remove print
commit 0af0102e61
parent 8a3247ac71
Author: Yina Chen
Date:   2024-06-14 18:46:52 +08:00 (committed by GitHub)
3 changed files with 48 additions and 17 deletions
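
The change threads a new enable_scale_search flag from model loading down to the native quantizer, and only turns it on for fp6 weights on non-qwen2 models. A minimal end-to-end sketch of how a user would hit the new path, assuming an ipex-llm build where "fp6" is an accepted load_in_low_bit string (the model id below is purely illustrative):

    from ipex_llm.transformers import AutoModelForCausalLM

    # With load_in_low_bit="fp6", ggml_convert_low_bit derives
    # enable_scale_search=True for non-qwen2 models (see use_scale_search
    # below) and threads it down to every low-bit linear layer; all other
    # qtypes keep the previous behaviour.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",   # illustrative model id
        load_in_low_bit="fp6",
        trust_remote_code=True,
    )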


@@ -950,8 +950,9 @@ def ggml_quantize_tensor(
     n: ctypes.c_size_t,
     k: ctypes.c_int,
     hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
+    scale_search: ctypes.c_bool,
 ) -> int:
-    return _lib.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
+    return _lib.ggml_quantize_tensor(src, dst, qtype, n, k, hist, scale_search)
 
 
 _lib.ggml_quantize_tensor.argtypes = [
@@ -961,6 +962,7 @@ _lib.ggml_quantize_tensor.argtypes = [
     ctypes.c_size_t,
     ctypes.c_int,
     ctypes.POINTER(ctypes.c_int64),
+    ctypes.c_bool,
 ]
 _lib.ggml_quantize_tensor.restype = ctypes.c_size_t
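
Since the binding now ends with a ctypes.c_bool, every caller has to pass the extra argument. A sketch of a call through the updated wrapper, mirroring the call site in ggml_convert_qtype later in this commit; the helper name and the assumption that the bindings module is passed in as ggml are illustrative:

    import ctypes

    def quantize_with_scale_search(ggml, tensor, dst_tensor, qtype,
                                   enable_scale_search=False):
        """Quantize fp32 `tensor` into the pre-allocated `dst_tensor` buffer."""
        n = tensor.numel()                                  # total element count
        k = tensor.shape[-1]                                # row length (multiple of the block size)
        src = ctypes.c_void_p(tensor.data.data_ptr())       # input fp32 data
        dst = ctypes.c_void_p(dst_tensor.data.data_ptr())   # packed low-bit output
        hist = (ctypes.c_int64 * 16)()                      # histogram buffer required by ggml
        # The new trailing bool toggles the quantization scale search;
        # passing False reproduces the old behaviour.
        return ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist,
                                         enable_scale_search)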


@@ -287,6 +287,12 @@ def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
     return ggml_weight, g_id_map
 
 
+def use_scale_search(model_config, qtype):
+    if qtype == ggml_tensor_qtype["fp6"] and model_config.model_type not in ["qwen2"]:
+        return True
+    return False
+
+
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
@@ -295,6 +301,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  enable_xetla=False,
                                  mixed_precision=False,
                                  act_order=False,
+                                 enable_scale_search=False,
                                  ):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
@@ -333,6 +340,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         enable_xetla=enable_xetla,
                         optimize_lm_head=optimize_lm_head,
                         act_order=act_order,
+                        enable_scale_search=enable_scale_search,
                     )
                     device = module.qweight.data.device
                     invalidInputError(device.type != "meta",
@@ -350,7 +358,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                              _shape=(out_features, in_features),
                                              convert_shape_only=convert_shape_only,
                                              qtype=qtype,
-                                             enable_xetla=enable_xetla).to(device)
+                                             enable_xetla=enable_xetla,
+                                             enable_scale_search=enable_scale_search).to(device)
                     new_linear._parameters['weight'] = paramsLowBit
                     if has_bias:
                         new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -376,7 +385,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         module.bias is not None,
                         mp_group=mp_group,
                         enable_xetla=enable_xetla,
-                        optimize_lm_head=optimize_lm_head
+                        optimize_lm_head=optimize_lm_head,
+                        enable_scale_search=enable_scale_search,
                     )
                     device = module.weight.data.device
                     # Copy the weights
@@ -388,7 +398,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                              qtype=cur_qtype,
                                              imatrix=cur_imatrix,
                                              in_features=in_features,
-                                             enable_xetla=enable_xetla).to(device)
+                                             enable_xetla=enable_xetla,
+                                             enable_scale_search=enable_scale_search).to(device)
                     new_linear._parameters['weight'] = paramsLowBit
                     if module.bias is not None:
                         new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -498,6 +509,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 enable_xetla=enable_xetla,
                 mixed_precision=mixed_precision,
                 act_order=act_order,
+                enable_scale_search=enable_scale_search,
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -769,17 +781,22 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     if getattr(model, "quantization_method", None) == "gptq":
         act_order = model.config.quantization_config.desc_act
 
+    model_config = getattr(model, "config", None)
+
+    enable_scale_search = use_scale_search(model_config, qtype)
+
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
         convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
-        model_config=getattr(model, "config", None),
+        model_config=model_config,
         torch_dtype=torch_dtype,
         enable_xetla=enable_xetla,
         mixed_precision=mixed_precision,
         act_order=act_order,
+        enable_scale_search=enable_scale_search,
     )
     if not has_been_replaced:
         warnings.warn(
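
For reference, the gating added above only enables the search for fp6 weights outside the qwen2 family. A standalone sketch of that rule, with illustrative stand-ins for the real ggml_tensor_qtype mapping and the Hugging Face model config:

    from types import SimpleNamespace

    GGML_TENSOR_QTYPE = {"sym_int4": 2, "fp6": 29}   # illustrative values only

    def use_scale_search(model_config, qtype):
        # Scale search is only enabled for fp6 and is skipped for qwen2 models.
        if qtype == GGML_TENSOR_QTYPE["fp6"] and model_config.model_type not in ["qwen2"]:
            return True
        return False

    assert use_scale_search(SimpleNamespace(model_type="llama"), GGML_TENSOR_QTYPE["fp6"])
    assert not use_scale_search(SimpleNamespace(model_type="qwen2"), GGML_TENSOR_QTYPE["fp6"])
    assert not use_scale_search(SimpleNamespace(model_type="llama"), GGML_TENSOR_QTYPE["sym_int4"])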


@@ -201,7 +201,8 @@ def get_qk_size(qtype: int):
 def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                        device=None, convert_shape_only=False,
                        imatrix: torch.Tensor=None,
-                       in_features: int=None):
+                       in_features: int=None,
+                       enable_scale_search: bool=False):
     QK = ggml.ggml_qk_size(qtype)
     block_size_in_bytes = ggml.ggml_type_size(qtype)
@@ -222,7 +223,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
         dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
         hist = (ctypes.c_int64 * 16)()
         if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
-            ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
+            ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
         else:
             if imatrix is not None:
                 # quantize with importance matrix
@@ -357,7 +358,8 @@ class FP4Params(torch.nn.Parameter):
                 qtype=None,
                 imatrix=None,
                 in_features=None,
-                enable_xetla=False,):
+                enable_xetla=False,
+                enable_scale_search=False):
         if data is None:
             data = torch.empty(0)
@@ -370,13 +372,15 @@ class FP4Params(torch.nn.Parameter):
         self.imatrix = imatrix
         self.in_features = in_features
         self.enable_xetla = enable_xetla
+        self.enable_scale_search = enable_scale_search
         return self
 
     def ggml_mse(self, w, ggml_qtype, device):
         from torch.nn.functional import mse_loss
         w_quant = ggml_convert_qtype(w, ggml_qtype,
                                      device=device,
-                                     convert_shape_only=self.convert_shape_only)
+                                     convert_shape_only=self.convert_shape_only,
+                                     enable_scale_search=self.enable_scale_search)
         w_dequant = ggml_convert_fp32(w_quant, w.shape,
                                       reduce(mul, w.shape, 1), ggml_qtype)
         mse = mse_loss(w_dequant, w)
@@ -389,7 +393,8 @@ class FP4Params(torch.nn.Parameter):
         if device == 'meta':
             w_quantized = ggml_convert_qtype(w, SYM_INT4,
                                              device=device,
-                                             convert_shape_only=self.convert_shape_only)
+                                             convert_shape_only=self.convert_shape_only,
+                                             enable_scale_search=self.enable_scale_search)
             # TODO: should load from config, the current implementation doesn't support
             # save/load
             self.qtype = SYM_INT4
@@ -406,7 +411,8 @@ class FP4Params(torch.nn.Parameter):
         if device == 'meta':
             w_quantized = ggml_convert_qtype(w, SYM_INT8,
                                              device=device,
-                                             convert_shape_only=self.convert_shape_only)
+                                             convert_shape_only=self.convert_shape_only,
+                                             enable_scale_search=self.enable_scale_search)
             # TODO: should load from config, the current implementation doesn't support
             # save/load
             self.qtype = SYM_INT8
@@ -424,7 +430,8 @@ class FP4Params(torch.nn.Parameter):
                                              device=device,
                                              convert_shape_only=self.convert_shape_only,
                                              imatrix=self.imatrix,
-                                             in_features=self.in_features)
+                                             in_features=self.in_features,
+                                             enable_scale_search=self.enable_scale_search)
             self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
@@ -467,7 +474,8 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_xetla=self.enable_xetla)
+                                  enable_xetla=self.enable_xetla,
+                                  enable_scale_search=self.enable_scale_search)
             if self.enable_xetla:
                 device_type = get_xpu_device_type(new_param.data)
                 invalidInputError(device_type == "pvc",
@@ -481,7 +489,8 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_xetla=self.enable_xetla)
+                                  enable_xetla=self.enable_xetla,
+                                  enable_scale_search=self.enable_scale_search)
             if self.enable_xetla:
                 ggml_xpu = ipex_llm_xetla_to_ggml_xpu(new_param.data,
                                                       new_param._shape,
@@ -500,7 +509,8 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_xetla=self.enable_xetla)
+                                  enable_xetla=self.enable_xetla,
+                                  enable_scale_search=self.enable_scale_search)
             return new_param
@@ -607,12 +617,14 @@ class MatMulLowBitCPU(torch.autograd.Function):
 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
                  conver_to_half=True, mp_group=None, enable_xetla=False,
-                 optimize_lm_head=False, act_order=False):
+                 optimize_lm_head=False, act_order=False,
+                 enable_scale_search=False):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
                                 quantized=False, _shape=None, qtype=qtype,
-                                enable_xetla=enable_xetla)
+                                enable_xetla=enable_xetla,
+                                enable_scale_search=enable_scale_search)
         self.in_len = input_features
         self.out_len = output_features
         self.weight_shape = (self.out_len, self.in_len)
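
The new keyword can also be exercised directly through ggml_convert_qtype. A minimal sketch, assuming ipex-llm is installed and that this file is importable as ipex_llm.transformers.low_bit_linear; note the commit itself only auto-enables the search for fp6, so passing the flag for other qtypes is purely illustrative:

    import torch
    from ipex_llm.transformers.low_bit_linear import SYM_INT4, ggml_convert_qtype

    weight = torch.randn(256, 256, dtype=torch.float32)
    # enable_scale_search=True asks the native quantizer to search for better
    # per-block scales; the default False keeps the previous behaviour.
    packed = ggml_convert_qtype(weight, SYM_INT4, device="cpu",
                                enable_scale_search=True)
    print(packed.shape, packed.dtype)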