Support imatrix-guided quantization for NPU CW (#12468)
* init commit
* remove print
* add interface
* fix
* fix
* fix style

parent f99f188023
commit 4b6c3160be

5 changed files with 104 additions and 21 deletions
@@ -1018,6 +1018,35 @@ _lib.ggml_quantize_tensor_rtn.argtypes = [
 _lib.ggml_quantize_tensor_rtn.restype = ctypes.c_size_t
 
 
+def ggml_quantize_tensor_rtn_with_weights(
+    src,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+    dst: ctypes.c_void_p,
+    scale_ptr,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+    qtype: ctypes.c_int,
+    n: ctypes.c_size_t,
+    k: ctypes.c_int,
+    hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
+    scale_search: ctypes.c_bool,
+    weights,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+) -> int:
+    return _lib.ggml_quantize_tensor_rtn_with_weights(src, dst, scale_ptr, qtype, n, k,
+                                                      hist, scale_search, weights)
+
+
+_lib.ggml_quantize_tensor_rtn_with_weights.argtypes = [
+    ctypes.POINTER(ctypes.c_float),
+    ctypes.c_void_p,
+    ctypes.POINTER(ctypes.c_float),
+    ctypes.c_int,
+    ctypes.c_size_t,
+    ctypes.c_int,
+    ctypes.POINTER(ctypes.c_int64),
+    ctypes.c_bool,
+    ctypes.POINTER(ctypes.c_float),
+]
+_lib.ggml_quantize_tensor_rtn_with_weights.restype = ctypes.c_size_t
+
+
 def ggml_type_size(qtype: ctypes.c_int) -> int:
     return _lib.ggml_type_size(qtype)
 
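
As a reference for callers of this new binding, here is a minimal sketch of preparing the extra buffers it expects; the element count and the 16-entry histogram length are illustrative assumptions, not values taken from this commit:

import ctypes

import torch

# Illustrative sizes only; a real caller derives them from the tensor being quantized.
k = 64                                          # elements per quantization row
weights = torch.rand(k, dtype=torch.float32)    # one row of the importance matrix
hist = (ctypes.c_int64 * 16)()                  # histogram buffer filled by the quantizer

# Float tensors are handed to the C side as float* via their data pointer,
# the same cast pattern used elsewhere in this change.
weights_ptr = ctypes.cast(weights.data_ptr(), ctypes.POINTER(ctypes.c_float))
print(len(hist), weights_ptr[0], float(weights[0]))   # pointer and tensor read the same buffer
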
				
			
@@ -246,8 +246,17 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
         if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
             if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
                 scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
-                ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
-                                              k, hist, enable_scale_search)
+                if imatrix is None:
+                    ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
+                                                  k, hist, enable_scale_search)
+                else:
+                    imatrix = imatrix.data.data_ptr()
+                    imatrix = ctypes.cast(imatrix, ctypes.POINTER(ctypes.c_float))
+                    ggml.ggml_quantize_tensor_rtn_with_weights(src, dst, scale_ptr,
+                                                               qtype, n,
+                                                               k, hist,
+                                                               enable_scale_search,
+                                                               imatrix)
                 return dst_tensor, scale.type(torch.float16)
             else:
                 ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
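
To make the intent of the weighted path concrete, here is a conceptual, pure-PyTorch sketch of importance-weighted round-to-nearest for one output channel; it is not the ggml kernel, and the scale grid and int4 range are illustrative assumptions:

import torch

def weighted_rtn_int4(w: torch.Tensor, importance: torch.Tensor, n_grid: int = 20):
    # Try a handful of candidate scales and keep the one that minimizes the
    # importance-weighted squared error, so columns with larger imatrix entries
    # are reproduced more faithfully than a plain max-abs RTN would.
    max_abs = w.abs().max().clamp_min(1e-8)
    best_scale, best_err = max_abs / 7.0, float("inf")
    for i in range(1, n_grid + 1):
        scale = max_abs * i / (n_grid * 7.0)            # symmetric int4 grid: [-8, 7]
        q = torch.clamp(torch.round(w / scale), -8, 7)
        err = (importance * (w - q * scale) ** 2).sum()
        if err < best_err:
            best_scale, best_err = scale, err
    q = torch.clamp(torch.round(w / best_scale), -8, 7)
    return q.to(torch.int8), best_scale

w = torch.randn(64)
imp = torch.rand(64)            # stands in for one row of the importance matrix
q, s = weighted_rtn_int4(w, imp)
print(q.min().item(), q.max().item(), float(s))
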
				
			
@@ -26,7 +26,7 @@ from transformers.dynamic_module_utils import get_imports
 from transformers.configuration_utils import PretrainedConfig
 
 from ipex_llm.utils.common.log4Error import invalidInputError
-from ipex_llm.transformers.utils import logger
+from ipex_llm.transformers.utils import logger, load_imatrix_data
 from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post
 
 

@@ -137,6 +137,12 @@ class _BaseAutoModelClass:
         convert_model = kwargs.pop('convert_model', False)
         save_directory = kwargs.pop('save_directory', None)
         fuse_layers = kwargs.pop('fuse_layers', None)
+        imatrix_file = kwargs.pop('imatrix_file', None)
+
+        if imatrix_file is not None:
+            imatrix_data = load_imatrix_data(imatrix_file)
+        else:
+            imatrix_data = None
 
         invalidInputError(
             quantization_group_size in [0, 32, 64, 128],
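
For callers, an illustrative sketch of the new keyword; the entry point is assumed to be the NPU `AutoModelForCausalLM` in `ipex_llm.transformers.npu_model`, and the model id, file name, and other keyword arguments are placeholders rather than values from this PR ("CW" in the title refers to channel-wise quantization, i.e. `quantization_group_size=0`):

from ipex_llm.transformers.npu_model import AutoModelForCausalLM

# Hypothetical paths; an imatrix file is typically produced offline
# (for example with llama.cpp's imatrix tool) over a calibration corpus.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",         # placeholder model id
    load_in_low_bit="sym_int4",         # assumed low-bit mode for the NPU RTN path
    imatrix_file="llama-2-7b.imatrix",  # new in this PR: guides the RTN quantization
    trust_remote_code=True,
)
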
				
			
@@ -205,7 +211,8 @@ class _BaseAutoModelClass:
                     "transpose_value_cache": transpose_value_cache,
                     "convert_model": convert_model,
                     "save_directory": save_directory,
-                    "fuse_layers": fuse_layers
+                    "fuse_layers": fuse_layers,
+                    "imatrix_data": imatrix_data
                 }
                 model = cls.optimize_npu_model(*args, **optimize_kwargs)
             else:

@@ -213,7 +220,8 @@ class _BaseAutoModelClass:
                 optimize_llm(model)
                 with torch.no_grad():
                     cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
-                                     quantization_group_size, *args, **kwargs)
+                                     quantization_group_size, imatrix_data=imatrix_data,
+                                     *args, **kwargs)
                     if hasattr(model, "llm"):
                         create_npu_kernels(model.llm)
                     else:

@@ -246,6 +254,7 @@ class _BaseAutoModelClass:
         convert_model = kwargs.pop('convert_model', False)
         save_directory = kwargs.pop('save_directory', None)
         fuse_layers = kwargs.pop('fuse_layers', None)
+        imatrix_data = kwargs.pop('imatrix_data', None)
 
         if hasattr(model, "llm"):
             llm = model.llm

@@ -258,7 +267,8 @@ class _BaseAutoModelClass:
             optimize_llm_pre(model, qtype, mixed_precision,
                              quantization_group_size=quantization_group_size)
             cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
-                             quantization_group_size, *args, **kwargs)
+                             quantization_group_size, imatrix_data,
+                             *args, **kwargs)
             create_npu_kernels(llm)
         model = model.eval()
         logger.info(f"Finish to convert model")

@@ -305,12 +315,12 @@ class _BaseAutoModelClass:
 
     @classmethod
     def load_convert(cls, q_k, optimize_model, device, modules_to_not_convert,
-                     group_size=0, *arg, **kwarg):
+                     group_size=0, imatrix_data=None, *arg, **kwarg):
         from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear
 
         replace_with_QuantizedLinear(optimize_model, q_k, device=device,
                                      modules_to_not_convert=modules_to_not_convert,
-                                     group_size=group_size)
+                                     group_size=group_size, imatrix=imatrix_data)
 
     @classmethod
     def load_convert_cpu(cls, q_k, optimize_model, device, modules_to_not_convert,
@@ -19,11 +19,11 @@ import os
 import torch
 import importlib
 from ipex_llm.transformers.npu_models.linear import QuantizedLinear
 import tempfile
 import time
 from typing import Callable, List, Optional
 from transformers import GenerationConfig, \
     LogitsProcessorList, StoppingCriteriaList
+from ipex_llm.transformers.utils import module_name_process
 
 
 def module_optimization(func) -> torch.nn.Module:

@@ -39,7 +39,7 @@ def module_optimization(func) -> torch.nn.Module:
     """
 
     def wrapper(model: torch.nn.Module, qtype, device, modules_to_not_convert,
-                group_size=0, *args, **kwargs):
+                group_size=0, imatrix=None, full_name="", *args, **kwargs):
         """Recursively apply the optimization function.
 
         Args:
@@ -49,23 +49,40 @@ def module_optimization(func) -> torch.nn.Module:
 
         """
         for name, layer in model.named_children():
+            if full_name == "":
+                cur_full_name = name
+            else:
+                cur_full_name = full_name + "." + name
+            cur_imatrix = None
+            if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
+                new_module_name, _, cur_module_name, dq_idx = module_name_process(cur_full_name)
+                if imatrix is not None and new_module_name in imatrix:
+                    cur_imatrix = imatrix[new_module_name]
+                    if cur_imatrix.shape[0] != layer.weight.shape[1]:
+                        ws = layer.weight.shape[1]
+                        cur_imatrix = cur_imatrix[ws * dq_idx: ws * (dq_idx + 1)]
             if name not in modules_to_not_convert:
                 new_layer = func(layer, qtype, device, modules_to_not_convert,
-                                 group_size=group_size, *args, **kwargs)
+                                 group_size=group_size, imatrix=cur_imatrix,
+                                 *args, **kwargs)
                 if new_layer:
                     model.add_module(name, new_layer)
                     wrapper(new_layer, qtype, device, modules_to_not_convert,
-                            group_size=group_size, *args, **kwargs)
+                            group_size=group_size, imatrix=imatrix,
+                            full_name=cur_full_name,
+                            *args, **kwargs)
                 else:
                     wrapper(layer, qtype, device, modules_to_not_convert,
-                            group_size=group_size, *args, **kwargs)
+                            group_size=group_size, imatrix=imatrix,
+                            full_name=cur_full_name,
+                            *args, **kwargs)
 
     return wrapper
 
 
 @module_optimization
 def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
-                                 group_size):
+                                 group_size, imatrix):
     from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
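
A toy, pure-PyTorch sketch of the traversal above: dotted names are built from named_children, and a dq-split Linear takes the slice of the importance vector that matches its input width (module names, sizes, and the fixed dq index here are illustrative, not from the PR):

import torch

def walk(model, imatrix, full_name=""):
    for name, layer in model.named_children():
        cur_full_name = name if full_name == "" else full_name + "." + name
        if isinstance(layer, torch.nn.Linear):
            ws = layer.weight.shape[1]                 # input width of this piece
            dq_idx = 0                                 # the real code parses this from the name
            cur_imatrix = imatrix
            if cur_imatrix.shape[0] != ws:             # dq-split piece: take its chunk
                cur_imatrix = cur_imatrix[ws * dq_idx: ws * (dq_idx + 1)]
            print(cur_full_name, tuple(layer.weight.shape), "imatrix slice:", cur_imatrix.shape[0])
        else:
            walk(layer, imatrix, cur_full_name)

toy = torch.nn.Sequential(torch.nn.Linear(8, 8),
                          torch.nn.Sequential(torch.nn.Linear(4, 2)))
walk(toy, torch.rand(8))   # prints entries for "0" and "1.0"
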
				
			
@@ -79,7 +96,8 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
         enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"
         qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
                                              iqtype, device=device,
-                                             enable_scale_search=enable_scale_search)
+                                             enable_scale_search=enable_scale_search,
+                                             imatrix=imatrix)
         return QuantizedLinear(qweights, scale, layer.bias,
                                group_size=group_size)
 

@@ -247,6 +247,10 @@ def module_name_process(full_module_name):
     else:
         super_module_name = None
     exp_id = None
+    new_module_name = None
+    layer = None
+    cur_module = None
+    dq_idx = None
     if super_module_name == 'block_sparse_moe':
         # handle mixtral moe here
         moe_mapping = {"w1": "gate", "w2": "down", "w3": "up"}
@@ -265,11 +269,24 @@ def module_name_process(full_module_name):
             layer = module_name_list[2]
             cur_module = module_name_list[-1][:-5]
             new_module_name = '_'.join([layer, cur_module])
+        elif len(module_name_list) == 6 and 'dq' in module_name_list[-1]:
+            # for NPU dq_list linear
+            layer = module_name_list[2]
+            cur_module = module_name_list[-1]
+            try:
+                dq_idx = int(cur_module[-2:])
+            except:
+                dq_idx = int(cur_module[-1:])
+            if cur_module[0] in 'qkvo':
+                cur_module = cur_module[0]
+            elif cur_module[:2] == "up":
+                cur_module = cur_module[:2]
+            elif cur_module[:4] == "gate" or cur_module[:4] == "down":
+                cur_module = cur_module[:4]
+            new_module_name = '_'.join([layer, cur_module])
         elif len(module_name_list) == 1:
             new_module_name = module_name_list[0]
             layer = None
             cur_module = None
-    return new_module_name, layer, cur_module
+    return new_module_name, layer, cur_module, dq_idx
 
 
 def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
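
A small sketch of the trailing-index parse used for dq-split module names: a two-digit suffix is tried first, then a single digit (the names below are illustrative, not the exact ones produced by the NPU pipeline):

def parse_dq_idx(cur_module: str) -> int:
    # Mirror of the try/except above: "..._dq_12" -> 12, "..._dq_3" -> 3.
    try:
        return int(cur_module[-2:])
    except ValueError:
        return int(cur_module[-1:])

print(parse_dq_idx("q_dq_3"))      # 3
print(parse_dq_idx("down_dq_12"))  # 12
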
				
			
@@ -283,7 +300,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_confi
     if qtype in [ggml_tensor_qtype["gguf_iq2_xxs"], ggml_tensor_qtype["gguf_iq2_xs"],
                  ggml_tensor_qtype["gguf_iq1_s"]]:
         # For quantization which needs importance matrix
-        new_module_name, layer, cur_module = module_name_process(full_module_name)
+        new_module_name, layer, cur_module, _ = module_name_process(full_module_name)
         # custom mixed quantization strategy
         if model_type == "mixtral":
             if cur_module == 'v':

@@ -312,7 +329,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_confi
             if new_module_name == 'lm_head':
                 cur_qtype = ggml_tensor_qtype['sym_int8']
     elif qtype == ggml_tensor_qtype["q2_k"]:
-        new_module_name, layer, cur_module = module_name_process(full_module_name)
+        new_module_name, layer, cur_module, _ = module_name_process(full_module_name)
         if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
             # TODO: q2_k need others k-quants type here
             cur_qtype = ggml_tensor_qtype['q2_k']

@@ -325,7 +342,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_confi
                 cur_qtype = ggml_tensor_qtype['sym_int8']
     elif qtype > 100:
         # gguf mixed precision
-        new_module_name, layer, cur_module = module_name_process(full_module_name)
+        new_module_name, layer, cur_module, _ = module_name_process(full_module_name)
         num_hidden_layers = getattr(model_config, "num_hidden_layers", None)
         if qtype in [gguf_mixed_qtype["gguf_q4k_s"], gguf_mixed_qtype["gguf_q4k_m"]] and \
                 new_module_name == 'lm_head':