Support imatrix-guided quantization for NPU CW (#12468)
* init commit
* remove print
* add interface
* fix
* fix
* fix style
This commit is contained in:
parent
f99f188023
commit
4b6c3160be
5 changed files with 104 additions and 21 deletions
@@ -1018,6 +1018,35 @@ _lib.ggml_quantize_tensor_rtn.argtypes = [
 _lib.ggml_quantize_tensor_rtn.restype = ctypes.c_size_t
 
 
+def ggml_quantize_tensor_rtn_with_weights(
+    src,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+    dst: ctypes.c_void_p,
+    scale_ptr,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+    qtype: ctypes.c_int,
+    n: ctypes.c_size_t,
+    k: ctypes.c_int,
+    hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
+    scale_search: ctypes.c_bool,
+    weights,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+) -> int:
+    return _lib.ggml_quantize_tensor_rtn_with_weights(src, dst, scale_ptr, qtype, n, k,
+                                                      hist, scale_search, weights)
+
+
+_lib.ggml_quantize_tensor_rtn_with_weights.argtypes = [
+    ctypes.POINTER(ctypes.c_float),
+    ctypes.c_void_p,
+    ctypes.POINTER(ctypes.c_float),
+    ctypes.c_int,
+    ctypes.c_size_t,
+    ctypes.c_int,
+    ctypes.POINTER(ctypes.c_int64),
+    ctypes.c_bool,
+    ctypes.POINTER(ctypes.c_float),
+]
+_lib.ggml_quantize_tensor_rtn_with_weights.restype = ctypes.c_size_t
+
+
 def ggml_type_size(qtype: ctypes.c_int) -> int:
     return _lib.ggml_type_size(qtype)
 
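For orientation, a hedged sketch of how the new binding could be driven, assuming the wrapper above is in scope. The example shapes, the oversized destination buffer, and the placeholder qtype id are assumptions, not part of this patch; the real call site is the ggml_convert_qtype hunk below.

import ctypes
import torch

# Assumed shapes: one [out_features, in_features] weight, channel-wise scales.
out_features, in_features = 4096, 4096
weight = torch.randn(out_features, in_features, dtype=torch.float32)
imatrix = torch.rand(in_features, dtype=torch.float32)   # per-input-channel importance
n, k = weight.numel(), in_features

dst = torch.empty(n, dtype=torch.int8)                    # deliberately oversized scratch buffer
scale = torch.empty(out_features, dtype=torch.float32)    # one scale per output channel
hist = (ctypes.c_int64 * 16)()

src_ptr = ctypes.cast(weight.data_ptr(), ctypes.POINTER(ctypes.c_float))
dst_ptr = ctypes.c_void_p(dst.data_ptr())
scale_ptr = ctypes.cast(scale.data_ptr(), ctypes.POINTER(ctypes.c_float))
weights_ptr = ctypes.cast(imatrix.data_ptr(), ctypes.POINTER(ctypes.c_float))

QTYPE_RTN = 0  # placeholder id; in practice the integer comes from ggml_tensor_qtype
ggml_quantize_tensor_rtn_with_weights(src_ptr, dst_ptr, scale_ptr,
                                      QTYPE_RTN, n, k, hist,
                                      False, weights_ptr)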
@@ -246,8 +246,17 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
     if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
         if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
             scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
-            ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
-                                          k, hist, enable_scale_search)
+            if imatrix is None:
+                ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
+                                              k, hist, enable_scale_search)
+            else:
+                imatrix = imatrix.data.data_ptr()
+                imatrix = ctypes.cast(imatrix, ctypes.POINTER(ctypes.c_float))
+                ggml.ggml_quantize_tensor_rtn_with_weights(src, dst, scale_ptr,
+                                                           qtype, n,
+                                                           k, hist,
+                                                           enable_scale_search,
+                                                           imatrix)
             return dst_tensor, scale.type(torch.float16)
         else:
             ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
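A hedged usage sketch of the branch above: quantizing one weight tensor with an importance vector through ggml_convert_qtype. The keyword form of the call mirrors the replace_with_QuantizedLinear hunk later in this diff; the qtype key and the tensor shapes are assumptions.

import torch
from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
from ipex_llm.ggml.quantize import ggml_tensor_qtype

weight = torch.randn(4096, 4096)                 # assumed [out_features, in_features]
imatrix = torch.rand(4096, dtype=torch.float32)  # assumed per-input-channel importance

# With imatrix=None this falls back to plain RTN quantization; with a tensor it
# takes the new ggml_quantize_tensor_rtn_with_weights path added above.
qweights, scale = ggml_convert_qtype(weight.to(torch.float32),
                                     ggml_tensor_qtype["sym_int4_rtn"],  # assumed qtype key
                                     device="cpu",
                                     enable_scale_search=False,
                                     imatrix=imatrix)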
@@ -26,7 +26,7 @@ from transformers.dynamic_module_utils import get_imports
 from transformers.configuration_utils import PretrainedConfig
 
 from ipex_llm.utils.common.log4Error import invalidInputError
-from ipex_llm.transformers.utils import logger
+from ipex_llm.transformers.utils import logger, load_imatrix_data
 from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post
@@ -137,6 +137,12 @@ class _BaseAutoModelClass:
         convert_model = kwargs.pop('convert_model', False)
         save_directory = kwargs.pop('save_directory', None)
         fuse_layers = kwargs.pop('fuse_layers', None)
+        imatrix_file = kwargs.pop('imatrix_file', None)
+
+        if imatrix_file is not None:
+            imatrix_data = load_imatrix_data(imatrix_file)
+        else:
+            imatrix_data = None
 
         invalidInputError(
             quantization_group_size in [0, 32, 64, 128],
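End-to-end, the new imatrix_file kwarg is meant to be passed at load time. A hedged sketch of what that could look like; the model id, file path, and the other kwargs are illustrative, not taken from this patch.

from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # assumed model id
    load_in_low_bit="sym_int4_rtn",    # assumed NPU RTN weight format
    optimize_model=True,
    imatrix_file="llama2-7b.imatrix",  # importance matrix produced offline (assumed path)
    trust_remote_code=True,
)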
@@ -205,7 +211,8 @@ class _BaseAutoModelClass:
                 "transpose_value_cache": transpose_value_cache,
                 "convert_model": convert_model,
                 "save_directory": save_directory,
-                "fuse_layers": fuse_layers
+                "fuse_layers": fuse_layers,
+                "imatrix_data": imatrix_data
             }
             model = cls.optimize_npu_model(*args, **optimize_kwargs)
         else:
@@ -213,7 +220,8 @@ class _BaseAutoModelClass:
             optimize_llm(model)
             with torch.no_grad():
                 cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
-                                 quantization_group_size, *args, **kwargs)
+                                 quantization_group_size, imatrix_data=imatrix_data,
+                                 *args, **kwargs)
                 if hasattr(model, "llm"):
                     create_npu_kernels(model.llm)
                 else:
@@ -246,6 +254,7 @@ class _BaseAutoModelClass:
         convert_model = kwargs.pop('convert_model', False)
         save_directory = kwargs.pop('save_directory', None)
         fuse_layers = kwargs.pop('fuse_layers', None)
+        imatrix_data = kwargs.pop('imatrix_data', None)
 
         if hasattr(model, "llm"):
             llm = model.llm
@@ -258,7 +267,8 @@ class _BaseAutoModelClass:
             optimize_llm_pre(model, qtype, mixed_precision,
                              quantization_group_size=quantization_group_size)
             cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
-                             quantization_group_size, *args, **kwargs)
+                             quantization_group_size, imatrix_data,
+                             *args, **kwargs)
             create_npu_kernels(llm)
         model = model.eval()
         logger.info(f"Finish to convert model")
@@ -305,12 +315,12 @@ class _BaseAutoModelClass:
 
     @classmethod
     def load_convert(cls, q_k, optimize_model, device, modules_to_not_convert,
-                     group_size=0, *arg, **kwarg):
+                     group_size=0, imatrix_data=None, *arg, **kwarg):
         from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear
 
         replace_with_QuantizedLinear(optimize_model, q_k, device=device,
                                      modules_to_not_convert=modules_to_not_convert,
-                                     group_size=group_size)
+                                     group_size=group_size, imatrix=imatrix_data)
 
     @classmethod
     def load_convert_cpu(cls, q_k, optimize_model, device, modules_to_not_convert,
@@ -19,11 +19,11 @@ import os
 import torch
 import importlib
 from ipex_llm.transformers.npu_models.linear import QuantizedLinear
-import tempfile
 import time
 from typing import Callable, List, Optional
 from transformers import GenerationConfig, \
     LogitsProcessorList, StoppingCriteriaList
+from ipex_llm.transformers.utils import module_name_process
 
 
 def module_optimization(func) -> torch.nn.Module:
@@ -39,7 +39,7 @@ def module_optimization(func) -> torch.nn.Module:
     """
 
     def wrapper(model: torch.nn.Module, qtype, device, modules_to_not_convert,
-                group_size=0, *args, **kwargs):
+                group_size=0, imatrix=None, full_name="", *args, **kwargs):
         """Recursively apply the optimization function.
 
         Args:
@@ -49,23 +49,40 @@ def module_optimization(func) -> torch.nn.Module:
         """
         for name, layer in model.named_children():
+            if full_name == "":
+                cur_full_name = name
+            else:
+                cur_full_name = full_name + "." + name
+            cur_imatrix = None
+            if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
+                new_module_name, _, cur_module_name, dq_idx = module_name_process(cur_full_name)
+                if imatrix is not None and new_module_name in imatrix:
+                    cur_imatrix = imatrix[new_module_name]
+                    if cur_imatrix.shape[0] != layer.weight.shape[1]:
+                        ws = layer.weight.shape[1]
+                        cur_imatrix = cur_imatrix[ws * dq_idx: ws * (dq_idx + 1)]
             if name not in modules_to_not_convert:
                 new_layer = func(layer, qtype, device, modules_to_not_convert,
-                                 group_size=group_size, *args, **kwargs)
+                                 group_size=group_size, imatrix=cur_imatrix,
+                                 *args, **kwargs)
                 if new_layer:
                     model.add_module(name, new_layer)
                     wrapper(new_layer, qtype, device, modules_to_not_convert,
-                            group_size=group_size, *args, **kwargs)
+                            group_size=group_size, imatrix=imatrix,
+                            full_name=cur_full_name,
+                            *args, **kwargs)
                 else:
                     wrapper(layer, qtype, device, modules_to_not_convert,
-                            group_size=group_size, *args, **kwargs)
+                            group_size=group_size, imatrix=imatrix,
+                            full_name=cur_full_name,
+                            *args, **kwargs)
 
     return wrapper
 
 
 @module_optimization
 def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
-                                 group_size):
+                                 group_size, imatrix):
     from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
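To make the dq_idx slicing above concrete, a hedged standalone illustration; the split factor, shapes, and imatrix key are assumptions, only the slicing arithmetic comes from the code above.

import torch

# Assumed: the original projection has 8192 input channels, but on NPU it was split
# into dq pieces of 4096 columns each; dq_idx (parsed from the module name) picks the piece.
full_imatrix = torch.rand(8192)           # e.g. imatrix["5_down"] loaded from the imatrix file
piece_weight = torch.empty(4096, 4096)    # one dq piece: [out_features, in_features]
dq_idx = 1

ws = piece_weight.shape[1]                                   # in_features of this piece
cur_imatrix = full_imatrix[ws * dq_idx: ws * (dq_idx + 1)]   # channels 4096..8191 for piece 1
assert cur_imatrix.shape[0] == ws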
@@ -79,7 +96,8 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
     enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"
     qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
                                          iqtype, device=device,
-                                         enable_scale_search=enable_scale_search)
+                                         enable_scale_search=enable_scale_search,
+                                         imatrix=imatrix)
     return QuantizedLinear(qweights, scale, layer.bias,
                            group_size=group_size)
@@ -247,6 +247,10 @@ def module_name_process(full_module_name):
     else:
         super_module_name = None
         exp_id = None
+    new_module_name = None
+    layer = None
+    cur_module = None
+    dq_idx = None
     if super_module_name == 'block_sparse_moe':
         # handle mixtral moe here
         moe_mapping = {"w1": "gate", "w2": "down", "w3": "up"}
@@ -265,11 +269,24 @@ def module_name_process(full_module_name):
         layer = module_name_list[2]
         cur_module = module_name_list[-1][:-5]
         new_module_name = '_'.join([layer, cur_module])
+    elif len(module_name_list) == 6 and 'dq' in module_name_list[-1]:
+        # for NPU dq_list linear
+        layer = module_name_list[2]
+        cur_module = module_name_list[-1]
+        try:
+            dq_idx = int(cur_module[-2:])
+        except:
+            dq_idx = int(cur_module[-1:])
+        if cur_module[0] in 'qkvo':
+            cur_module = cur_module[0]
+        elif cur_module[:2] == "up":
+            cur_module = cur_module[:2]
+        elif cur_module[:4] == "gate" or cur_module[:4] == "down":
+            cur_module = cur_module[:4]
+        new_module_name = '_'.join([layer, cur_module])
     elif len(module_name_list) == 1:
         new_module_name = module_name_list[0]
-        layer = None
-        cur_module = None
-    return new_module_name, layer, cur_module
+    return new_module_name, layer, cur_module, dq_idx
 
 
 def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
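For reference, a hedged walk-through of the new dq branch on a made-up module path. The path itself is an assumption, chosen to have six dot-separated parts with the layer index third and a dq suffix last; only the parsing steps come from the code above.

from ipex_llm.transformers.utils import module_name_process

# Hypothetical name of one split piece of a down projection in layer 5:
name = "model.layers.5.mlp.down_proj.down_proj_dq_1"
# split('.') gives 6 parts, so the new branch applies:
#   layer      = parts[2]   -> "5"
#   cur_module = parts[-1]  -> "down_proj_dq_1"
#   dq_idx     : int("_1") raises, so the fallback int("1") -> 1
#   prefix check: starts with "down" -> cur_module = "down"
#   new_module_name = "5_down", the key looked up in the imatrix dict.
new_module_name, layer, cur_module, dq_idx = module_name_process(name)
# expected under these assumptions: ("5_down", "5", "down", 1)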
@@ -283,7 +300,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
     if qtype in [ggml_tensor_qtype["gguf_iq2_xxs"], ggml_tensor_qtype["gguf_iq2_xs"],
                  ggml_tensor_qtype["gguf_iq1_s"]]:
         # For quantization which needs importance matrix
-        new_module_name, layer, cur_module = module_name_process(full_module_name)
+        new_module_name, layer, cur_module, _ = module_name_process(full_module_name)
         # custom mixed quantization strategy
         if model_type == "mixtral":
             if cur_module == 'v':
@@ -312,7 +329,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
         if new_module_name == 'lm_head':
             cur_qtype = ggml_tensor_qtype['sym_int8']
     elif qtype == ggml_tensor_qtype["q2_k"]:
-        new_module_name, layer, cur_module = module_name_process(full_module_name)
+        new_module_name, layer, cur_module, _ = module_name_process(full_module_name)
         if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
             # TODO: q2_k need others k-quants type here
             cur_qtype = ggml_tensor_qtype['q2_k']
@@ -325,7 +342,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
             cur_qtype = ggml_tensor_qtype['sym_int8']
     elif qtype > 100:
         # gguf mixed precision
-        new_module_name, layer, cur_module = module_name_process(full_module_name)
+        new_module_name, layer, cur_module, _ = module_name_process(full_module_name)
         num_hidden_layers = getattr(model_config, "num_hidden_layers", None)
         if qtype in [gguf_mixed_qtype["gguf_q4k_s"], gguf_mixed_qtype["gguf_q4k_m"]] and \
                 new_module_name == 'lm_head':