Support imatrix-guided quantization for NPU CW (#12468)

* init commit

* remove print

* add interface

* fix

* fix

* fix style
Ruonan Wang 2024-12-01 19:31:26 -08:00 committed by GitHub
parent f99f188023
commit 4b6c3160be
5 changed files with 104 additions and 21 deletions
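With this change, an importance matrix can be supplied at load time via the new `imatrix_file` keyword argument; it is read with `load_imatrix_data` and threaded through `load_convert` / `replace_with_QuantizedLinear` down to the RTN quantization kernel. A minimal usage sketch (not part of this diff; the entry point, checkpoint, and low-bit value below are illustrative assumptions):

```python
# Hedged sketch: pass an importance matrix when loading a model for the NPU.
# Model id, low-bit type and imatrix file name are placeholders.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    load_in_low_bit="sym_int4",          # assumed NPU channel-wise RTN setting
    optimize_model=True,
    imatrix_file="llama-2-7b.imatrix",   # importance matrix, e.g. produced by llama.cpp's imatrix tool
    trust_remote_code=True,
)
```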


@@ -1018,6 +1018,35 @@ _lib.ggml_quantize_tensor_rtn.argtypes = [
 _lib.ggml_quantize_tensor_rtn.restype = ctypes.c_size_t
+
+
+def ggml_quantize_tensor_rtn_with_weights(
+    src,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+    dst: ctypes.c_void_p,
+    scale_ptr,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+    qtype: ctypes.c_int,
+    n: ctypes.c_size_t,
+    k: ctypes.c_int,
+    hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
+    scale_search: ctypes.c_bool,
+    weights,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+) -> int:
+    return _lib.ggml_quantize_tensor_rtn_with_weights(src, dst, scale_ptr, qtype, n, k,
+                                                      hist, scale_search, weights)
+
+
+_lib.ggml_quantize_tensor_rtn_with_weights.argtypes = [
+    ctypes.POINTER(ctypes.c_float),
+    ctypes.c_void_p,
+    ctypes.POINTER(ctypes.c_float),
+    ctypes.c_int,
+    ctypes.c_size_t,
+    ctypes.c_int,
+    ctypes.POINTER(ctypes.c_int64),
+    ctypes.c_bool,
+    ctypes.POINTER(ctypes.c_float),
+]
+_lib.ggml_quantize_tensor_rtn_with_weights.restype = ctypes.c_size_t
 
 
 def ggml_type_size(qtype: ctypes.c_int) -> int:
     return _lib.ggml_type_size(qtype)
 
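For context, the new binding mirrors `ggml_quantize_tensor_rtn` with one extra `weights` pointer. A hedged sketch of how a caller prepares that pointer (shapes are illustrative; the real caller is `ggml_convert_qtype` in the next file):

```python
# Hedged sketch: casting torch buffers to the ctypes pointers this binding expects.
import ctypes
import torch

weight = torch.randn(4096, 4096, dtype=torch.float32)   # float32, contiguous weight to quantize
imatrix = torch.rand(4096, dtype=torch.float32)          # one importance value per input column

src = ctypes.cast(weight.data_ptr(), ctypes.POINTER(ctypes.c_float))
weights_ptr = ctypes.cast(imatrix.data_ptr(), ctypes.POINTER(ctypes.c_float))
# src and weights_ptr are then handed to ggml_quantize_tensor_rtn_with_weights
# together with dst, scale_ptr, qtype, n, k, hist and scale_search, exactly as
# ggml_convert_qtype does below.
```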


@@ -246,8 +246,17 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
     if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
         if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
             scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
-            ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
-                                          k, hist, enable_scale_search)
+            if imatrix is None:
+                ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
+                                              k, hist, enable_scale_search)
+            else:
+                imatrix = imatrix.data.data_ptr()
+                imatrix = ctypes.cast(imatrix, ctypes.POINTER(ctypes.c_float))
+                ggml.ggml_quantize_tensor_rtn_with_weights(src, dst, scale_ptr,
+                                                           qtype, n,
+                                                           k, hist,
+                                                           enable_scale_search,
+                                                           imatrix)
             return dst_tensor, scale.type(torch.float16)
         else:
             ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
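A hedged sketch of exercising this path directly; it assumes `sym_int4_rtn` is a valid `ggml_tensor_qtype` key and uses illustrative shapes:

```python
# Hedged sketch: imatrix-guided RTN quantization of a single weight matrix.
import torch
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype

weight = torch.randn(4096, 11008, dtype=torch.float32)   # (out_features, in_features)
imatrix = torch.rand(11008, dtype=torch.float32)          # importance per input channel

qweights, scale = ggml_convert_qtype(weight,
                                     ggml_tensor_qtype["sym_int4_rtn"],
                                     device="cpu",
                                     imatrix=imatrix)      # falls back to plain RTN when imatrix is None
```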


@@ -26,7 +26,7 @@ from transformers.dynamic_module_utils import get_imports
 from transformers.configuration_utils import PretrainedConfig
 from ipex_llm.utils.common.log4Error import invalidInputError
-from ipex_llm.transformers.utils import logger
+from ipex_llm.transformers.utils import logger, load_imatrix_data
 from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post
@@ -137,6 +137,12 @@ class _BaseAutoModelClass:
         convert_model = kwargs.pop('convert_model', False)
         save_directory = kwargs.pop('save_directory', None)
         fuse_layers = kwargs.pop('fuse_layers', None)
+        imatrix_file = kwargs.pop('imatrix_file', None)
+        if imatrix_file is not None:
+            imatrix_data = load_imatrix_data(imatrix_file)
+        else:
+            imatrix_data = None
+
 
         invalidInputError(
             quantization_group_size in [0, 32, 64, 128],
@@ -205,7 +211,8 @@ class _BaseAutoModelClass:
                 "transpose_value_cache": transpose_value_cache,
                 "convert_model": convert_model,
                 "save_directory": save_directory,
-                "fuse_layers": fuse_layers
+                "fuse_layers": fuse_layers,
+                "imatrix_data": imatrix_data
             }
             model = cls.optimize_npu_model(*args, **optimize_kwargs)
         else:
@@ -213,7 +220,8 @@ class _BaseAutoModelClass:
             optimize_llm(model)
             with torch.no_grad():
                 cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
-                                 quantization_group_size, *args, **kwargs)
+                                 quantization_group_size, imatrix_data=imatrix_data,
+                                 *args, **kwargs)
                 if hasattr(model, "llm"):
                     create_npu_kernels(model.llm)
                 else:
@@ -246,6 +254,7 @@ class _BaseAutoModelClass:
         convert_model = kwargs.pop('convert_model', False)
         save_directory = kwargs.pop('save_directory', None)
         fuse_layers = kwargs.pop('fuse_layers', None)
+        imatrix_data = kwargs.pop('imatrix_data', None)
 
         if hasattr(model, "llm"):
             llm = model.llm
@@ -258,7 +267,8 @@ class _BaseAutoModelClass:
             optimize_llm_pre(model, qtype, mixed_precision,
                              quantization_group_size=quantization_group_size)
             cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
-                             quantization_group_size, *args, **kwargs)
+                             quantization_group_size, imatrix_data,
+                             *args, **kwargs)
             create_npu_kernels(llm)
         model = model.eval()
         logger.info(f"Finish to convert model")
@@ -305,12 +315,12 @@ class _BaseAutoModelClass:
 
     @classmethod
     def load_convert(cls, q_k, optimize_model, device, modules_to_not_convert,
-                     group_size=0, *arg, **kwarg):
+                     group_size=0, imatrix_data=None, *arg, **kwarg):
         from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear
         replace_with_QuantizedLinear(optimize_model, q_k, device=device,
                                      modules_to_not_convert=modules_to_not_convert,
-                                     group_size=group_size)
+                                     group_size=group_size, imatrix=imatrix_data)
 
     @classmethod
     def load_convert_cpu(cls, q_k, optimize_model, device, modules_to_not_convert,


@@ -19,11 +19,11 @@ import os
 import torch
 import importlib
 from ipex_llm.transformers.npu_models.linear import QuantizedLinear
-import tempfile
 import time
 from typing import Callable, List, Optional
 from transformers import GenerationConfig, \
     LogitsProcessorList, StoppingCriteriaList
+from ipex_llm.transformers.utils import module_name_process
 
 
 def module_optimization(func) -> torch.nn.Module:
@@ -39,7 +39,7 @@ def module_optimization(func) -> torch.nn.Module:
     """
 
     def wrapper(model: torch.nn.Module, qtype, device, modules_to_not_convert,
-                group_size=0, *args, **kwargs):
+                group_size=0, imatrix=None, full_name="", *args, **kwargs):
         """Recursively apply the optimization function.
 
         Args:
@@ -49,23 +49,40 @@
         """
 
         for name, layer in model.named_children():
+            if full_name == "":
+                cur_full_name = name
+            else:
+                cur_full_name = full_name + "." + name
+            cur_imatrix = None
+            if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
+                new_module_name, _, cur_module_name, dq_idx = module_name_process(cur_full_name)
+                if imatrix is not None and new_module_name in imatrix:
+                    cur_imatrix = imatrix[new_module_name]
+                    if cur_imatrix.shape[0] != layer.weight.shape[1]:
+                        ws = layer.weight.shape[1]
+                        cur_imatrix = cur_imatrix[ws * dq_idx: ws * (dq_idx + 1)]
             if name not in modules_to_not_convert:
                 new_layer = func(layer, qtype, device, modules_to_not_convert,
-                                 group_size=group_size, *args, **kwargs)
+                                 group_size=group_size, imatrix=cur_imatrix,
+                                 *args, **kwargs)
                 if new_layer:
                     model.add_module(name, new_layer)
                     wrapper(new_layer, qtype, device, modules_to_not_convert,
-                            group_size=group_size, *args, **kwargs)
+                            group_size=group_size, imatrix=imatrix,
+                            full_name=cur_full_name,
+                            *args, **kwargs)
                 else:
                     wrapper(layer, qtype, device, modules_to_not_convert,
-                            group_size=group_size, *args, **kwargs)
+                            group_size=group_size, imatrix=imatrix,
+                            full_name=cur_full_name,
+                            *args, **kwargs)
 
     return wrapper
 
 
 @module_optimization
 def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
-                                 group_size):
+                                 group_size, imatrix):
     from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
@@ -79,7 +96,8 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
         enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"
         qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
                                              iqtype, device=device,
-                                             enable_scale_search=enable_scale_search)
+                                             enable_scale_search=enable_scale_search,
+                                             imatrix=imatrix)
         return QuantizedLinear(qweights, scale, layer.bias,
                                group_size=group_size)
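The `cur_imatrix` slicing in the wrapper above handles projections that are split into several narrower NPU "dq" sub-linears: each sub-linear only receives the part of the importance matrix that covers its own input columns. A small illustration with made-up sizes:

```python
# Hedged illustration of the slice cur_imatrix[ws * dq_idx: ws * (dq_idx + 1)].
import torch

full_imatrix = torch.rand(11008)   # importance per input channel of the original linear
ws = 2752                          # in_features of one dq sub-linear (layer.weight.shape[1])
dq_idx = 2                         # index parsed from the module name by module_name_process
cur_imatrix = full_imatrix[ws * dq_idx: ws * (dq_idx + 1)]
assert cur_imatrix.shape[0] == ws  # matches the sub-linear's input width
```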


@@ -247,6 +247,10 @@ def module_name_process(full_module_name):
     else:
         super_module_name = None
         exp_id = None
+    new_module_name = None
+    layer = None
+    cur_module = None
+    dq_idx = None
     if super_module_name == 'block_sparse_moe':
         # handle mixtral moe here
         moe_mapping = {"w1": "gate", "w2": "down", "w3": "up"}
@@ -265,11 +269,24 @@ def module_name_process(full_module_name):
             layer = module_name_list[2]
             cur_module = module_name_list[-1][:-5]
             new_module_name = '_'.join([layer, cur_module])
+        elif len(module_name_list) == 6 and 'dq' in module_name_list[-1]:
+            # for NPU dq_list linear
+            layer = module_name_list[2]
+            cur_module = module_name_list[-1]
+            try:
+                dq_idx = int(cur_module[-2:])
+            except:
+                dq_idx = int(cur_module[-1:])
+            if cur_module[0] in 'qkvo':
+                cur_module = cur_module[0]
+            elif cur_module[:2] == "up":
+                cur_module = cur_module[:2]
+            elif cur_module[:4] == "gate" or cur_module[:4] == "down":
+                cur_module = cur_module[:4]
+            new_module_name = '_'.join([layer, cur_module])
         elif len(module_name_list) == 1:
             new_module_name = module_name_list[0]
-            layer = None
-            cur_module = None
-    return new_module_name, layer, cur_module
+    return new_module_name, layer, cur_module, dq_idx
 
 
 def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
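To make the new branch concrete, a hedged example of what `module_name_process` returns for a dq-split linear. The module path below is hypothetical; any name with six dot-separated parts whose last part contains `dq` takes this branch.

```python
# Hypothetical module path; real NPU dq_list names may differ.
from ipex_llm.transformers.utils import module_name_process

name = "model.layers.5.mlp.down_proj.down_proj_dq_1"
new_module_name, layer, cur_module, dq_idx = module_name_process(name)
# new_module_name == "5_down", layer == "5", cur_module == "down", dq_idx == 1
# "5_down" is the key looked up in the loaded imatrix dict; dq_idx then selects
# which slice of that imatrix belongs to this sub-linear.
```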
@@ -283,7 +300,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
     if qtype in [ggml_tensor_qtype["gguf_iq2_xxs"], ggml_tensor_qtype["gguf_iq2_xs"],
                  ggml_tensor_qtype["gguf_iq1_s"]]:
         # For quantization which needs importance matrix
-        new_module_name, layer, cur_module = module_name_process(full_module_name)
+        new_module_name, layer, cur_module, _ = module_name_process(full_module_name)
         # custom mixed quantization strategy
         if model_type == "mixtral":
             if cur_module == 'v':
@@ -312,7 +329,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
         if new_module_name == 'lm_head':
             cur_qtype = ggml_tensor_qtype['sym_int8']
     elif qtype == ggml_tensor_qtype["q2_k"]:
-        new_module_name, layer, cur_module = module_name_process(full_module_name)
+        new_module_name, layer, cur_module, _ = module_name_process(full_module_name)
         if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
             # TODO: q2_k need others k-quants type here
             cur_qtype = ggml_tensor_qtype['q2_k']
@@ -325,7 +342,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
             cur_qtype = ggml_tensor_qtype['sym_int8']
     elif qtype > 100:
         # gguf mixed precision
-        new_module_name, layer, cur_module = module_name_process(full_module_name)
+        new_module_name, layer, cur_module, _ = module_name_process(full_module_name)
         num_hidden_layers = getattr(model_config, "num_hidden_layers", None)
         if qtype in [gguf_mixed_qtype["gguf_q4k_s"], gguf_mixed_qtype["gguf_q4k_m"]] and \
                 new_module_name == 'lm_head':