Add quantization scale search switch (#11326)
* add scale_search switch
* remove llama3 instruct
* remove print
parent 8a3247ac71
commit 0af0102e61
3 changed files with 48 additions and 17 deletions
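This commit threads a single boolean, decided once per model from its config, down to the ggml quantization call. Below is a minimal sketch of the resulting call chain, reconstructed from the hunks in this commit; the surrounding variables (`model`, `qtype`, `src`, `dst`, `hist`, ...) are illustrative, not part of the diff.

# Sketch of the propagation path for the new switch (reconstructed from the
# hunks below; variable setup here is illustrative, not verbatim source).

# 1. ggml_convert_low_bit() decides once, from the model config, whether to
#    search for better quantization scales for this qtype/model combination.
model_config = getattr(model, "config", None)
enable_scale_search = use_scale_search(model_config, qtype)

# 2. The flag is threaded through _replace_with_low_bit_linear(), which passes
#    it to every LowBitLinear / FP4Params replacement it builds.
model, has_been_replaced = _replace_with_low_bit_linear(
    model, qtype, modules_to_not_convert,
    enable_scale_search=enable_scale_search,
)

# 3. FP4Params forwards it to ggml_convert_qtype() when quantizing the weight,
#    which hands it to the ggml kernel as the new trailing ctypes.c_bool argument.
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)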
@@ -950,8 +950,9 @@ def ggml_quantize_tensor(
     n: ctypes.c_size_t,
     k: ctypes.c_int,
     hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
+    scale_search: ctypes.c_bool,
 ) -> int:
-    return _lib.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
+    return _lib.ggml_quantize_tensor(src, dst, qtype, n, k, hist, scale_search)
 
 
 _lib.ggml_quantize_tensor.argtypes = [
@@ -961,6 +962,7 @@ _lib.ggml_quantize_tensor.argtypes = [
     ctypes.c_size_t,
     ctypes.c_int,
     ctypes.POINTER(ctypes.c_int64),
+    ctypes.c_bool,
 ]
 _lib.ggml_quantize_tensor.restype = ctypes.c_size_t
 
@@ -287,6 +287,12 @@ def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
     return ggml_weight, g_id_map
 
 
+def use_scale_search(model_config, qtype):
+    if qtype == ggml_tensor_qtype["fp6"] and model_config.model_type not in ["qwen2"]:
+        return True
+    return False
+
+
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
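As added above, scale search is only turned on for the fp6 qtype, and only for model types other than qwen2. A quick illustration of the heuristic; the config objects below are stand-ins, not real HuggingFace configs.

# Illustration of the heuristic added above; _Cfg is a stand-in config object.
class _Cfg:
    def __init__(self, model_type):
        self.model_type = model_type

fp6 = ggml_tensor_qtype["fp6"]
use_scale_search(_Cfg("llama"), fp6)   # True: fp6 on a non-qwen2 model
use_scale_search(_Cfg("qwen2"), fp6)   # False: qwen2 is excluded
use_scale_search(_Cfg("llama"), 0)     # False: any non-fp6 qtype (0 is a placeholder)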
@@ -295,6 +301,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  enable_xetla=False,
                                  mixed_precision=False,
                                  act_order=False,
+                                 enable_scale_search=False,
                                  ):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
@@ -333,6 +340,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     enable_xetla=enable_xetla,
                     optimize_lm_head=optimize_lm_head,
                     act_order=act_order,
+                    enable_scale_search=enable_scale_search,
                 )
                 device = module.qweight.data.device
                 invalidInputError(device.type != "meta",
@@ -350,7 +358,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                          _shape=(out_features, in_features),
                                          convert_shape_only=convert_shape_only,
                                          qtype=qtype,
-                                         enable_xetla=enable_xetla).to(device)
+                                         enable_xetla=enable_xetla,
+                                         enable_scale_search=enable_scale_search).to(device)
                 new_linear._parameters['weight'] = paramsLowBit
                 if has_bias:
                     new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -376,7 +385,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                              module.bias is not None,
                                              mp_group=mp_group,
                                              enable_xetla=enable_xetla,
-                                             optimize_lm_head=optimize_lm_head
+                                             optimize_lm_head=optimize_lm_head,
+                                             enable_scale_search=enable_scale_search,
                                              )
                 device = module.weight.data.device
                 # Copy the weights
@@ -388,7 +398,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                          qtype=cur_qtype,
                                          imatrix=cur_imatrix,
                                          in_features=in_features,
-                                         enable_xetla=enable_xetla).to(device)
+                                         enable_xetla=enable_xetla,
+                                         enable_scale_search=enable_scale_search).to(device)
                 new_linear._parameters['weight'] = paramsLowBit
                 if module.bias is not None:
                     new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -498,6 +509,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 enable_xetla=enable_xetla,
                 mixed_precision=mixed_precision,
                 act_order=act_order,
+                enable_scale_search=enable_scale_search,
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -769,17 +781,22 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     if getattr(model, "quantization_method", None) == "gptq":
         act_order = model.config.quantization_config.desc_act
 
+    model_config = getattr(model, "config", None)
+
+    enable_scale_search = use_scale_search(model_config, qtype)
+
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
         convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
-        model_config=getattr(model, "config", None),
+        model_config=model_config,
         torch_dtype=torch_dtype,
         enable_xetla=enable_xetla,
         mixed_precision=mixed_precision,
         act_order=act_order,
+        enable_scale_search=enable_scale_search,
     )
     if not has_been_replaced:
         warnings.warn(
@@ -201,7 +201,8 @@ def get_qk_size(qtype: int):
 def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                        device=None, convert_shape_only=False,
                        imatrix: torch.Tensor=None,
-                       in_features: int=None):
+                       in_features: int=None,
+                       enable_scale_search: bool=False):
     QK = ggml.ggml_qk_size(qtype)
     block_size_in_bytes = ggml.ggml_type_size(qtype)
 
@@ -222,7 +223,7 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
         dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
         hist = (ctypes.c_int64 * 16)()
         if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
-            ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
+            ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
         else:
             if imatrix is not None:
                 # quantize with importance matrix
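With the new keyword in place, a caller can request scale search directly when quantizing a raw tensor. A hedged sketch follows; only `ggml_convert_qtype` and its `enable_scale_search` keyword come from the hunks above, while the tensor shape, device, and the fp6 lookup key are illustrative.

# Hedged usage sketch, not part of the diff: quantize a weight matrix with
# scale search enabled. Shapes and device are illustrative.
import torch

weight = torch.randn(4096, 4096, dtype=torch.float32)
q_weight = ggml_convert_qtype(weight, ggml_tensor_qtype["fp6"],
                              device="cpu",
                              enable_scale_search=True)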
@@ -357,7 +358,8 @@ class FP4Params(torch.nn.Parameter):
                 qtype=None,
                 imatrix=None,
                 in_features=None,
-                enable_xetla=False,):
+                enable_xetla=False,
+                enable_scale_search=False):
         if data is None:
             data = torch.empty(0)
 
@@ -370,13 +372,15 @@ class FP4Params(torch.nn.Parameter):
         self.imatrix = imatrix
         self.in_features = in_features
         self.enable_xetla = enable_xetla
+        self.enable_scale_search = enable_scale_search
         return self
 
     def ggml_mse(self, w, ggml_qtype, device):
         from torch.nn.functional import mse_loss
         w_quant = ggml_convert_qtype(w, ggml_qtype,
                                      device=device,
-                                     convert_shape_only=self.convert_shape_only)
+                                     convert_shape_only=self.convert_shape_only,
+                                     enable_scale_search=self.enable_scale_search)
         w_dequant = ggml_convert_fp32(w_quant, w.shape,
                                       reduce(mul, w.shape, 1), ggml_qtype)
         mse = mse_loss(w_dequant, w)
@@ -389,7 +393,8 @@ class FP4Params(torch.nn.Parameter):
             if device == 'meta':
                 w_quantized = ggml_convert_qtype(w, SYM_INT4,
                                                  device=device,
-                                                 convert_shape_only=self.convert_shape_only)
+                                                 convert_shape_only=self.convert_shape_only,
+                                                 enable_scale_search=self.enable_scale_search)
                 # TODO: should load from config, the current implementation doesn't support
                 # save/load
                 self.qtype = SYM_INT4
@@ -406,7 +411,8 @@ class FP4Params(torch.nn.Parameter):
             if device == 'meta':
                 w_quantized = ggml_convert_qtype(w, SYM_INT8,
                                                  device=device,
-                                                 convert_shape_only=self.convert_shape_only)
+                                                 convert_shape_only=self.convert_shape_only,
+                                                 enable_scale_search=self.enable_scale_search)
                 # TODO: should load from config, the current implementation doesn't support
                 # save/load
                 self.qtype = SYM_INT8
@@ -424,7 +430,8 @@ class FP4Params(torch.nn.Parameter):
                                              device=device,
                                              convert_shape_only=self.convert_shape_only,
                                              imatrix=self.imatrix,
-                                             in_features=self.in_features)
+                                             in_features=self.in_features,
+                                             enable_scale_search=self.enable_scale_search)
             self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
@@ -467,7 +474,8 @@ class FP4Params(torch.nn.Parameter):
                                  quantized=self.quantized,
                                  _shape=self._shape,
                                  qtype=self.qtype,
-                                 enable_xetla=self.enable_xetla)
+                                 enable_xetla=self.enable_xetla,
+                                 enable_scale_search=self.enable_scale_search)
            if self.enable_xetla:
                device_type = get_xpu_device_type(new_param.data)
                invalidInputError(device_type == "pvc",
@@ -481,7 +489,8 @@ class FP4Params(torch.nn.Parameter):
                                  quantized=self.quantized,
                                  _shape=self._shape,
                                  qtype=self.qtype,
-                                 enable_xetla=self.enable_xetla)
+                                 enable_xetla=self.enable_xetla,
+                                 enable_scale_search=self.enable_scale_search)
            if self.enable_xetla:
                ggml_xpu = ipex_llm_xetla_to_ggml_xpu(new_param.data,
                                                      new_param._shape,
@@ -500,7 +509,8 @@ class FP4Params(torch.nn.Parameter):
                                  quantized=self.quantized,
                                  _shape=self._shape,
                                  qtype=self.qtype,
-                                 enable_xetla=self.enable_xetla)
+                                 enable_xetla=self.enable_xetla,
+                                 enable_scale_search=self.enable_scale_search)
            return new_param
 
 
@@ -607,12 +617,14 @@ class MatMulLowBitCPU(torch.autograd.Function):
 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
                  conver_to_half=True, mp_group=None, enable_xetla=False,
-                 optimize_lm_head=False, act_order=False):
+                 optimize_lm_head=False, act_order=False,
+                 enable_scale_search=False):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
                                 quantized=False, _shape=None, qtype=qtype,
-                                enable_xetla=enable_xetla)
+                                enable_xetla=enable_xetla,
+                                enable_scale_search=enable_scale_search)
         self.in_len = input_features
         self.out_len = output_features
         self.weight_shape = (self.out_len, self.in_len)
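Since `LowBitLinear.__init__` now accepts the switch, a replacement layer can be constructed with scale search requested up front. A minimal sketch; the layer sizes and the concrete qtype value are illustrative, only the constructor signature comes from the hunk above.

# Hedged sketch, not part of the diff: build a low-bit layer with scale
# search requested; the weight is quantized through FP4Params as shown above.
layer = LowBitLinear(4096, 11008, qtype=ggml_tensor_qtype["fp6"],
                     bias=False, enable_scale_search=True)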