Init NPU quantize method and support q8_0_rtn (#11452)

* q8_0_rtn

* fix float point
This commit is contained in:
Zhao Changmin 2024-07-01 13:45:07 +08:00 committed by GitHub
parent 319a3b36b2
commit cf8eb7b128
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 101 additions and 7 deletions

View file

@ -991,6 +991,33 @@ _lib.ggml_quantize_tensor_with_weights.argtypes = [
_lib.ggml_quantize_tensor_with_weights.restype = ctypes.c_size_t
# GGML API
def ggml_quantize_tensor_rtn(
src, # type: ctypes.Array[ctypes.c_float] # type: ignore
dst: ctypes.c_void_p,
scale_ptr, # type: ctypes.Array[ctypes.c_float] # type: ignore
qtype: ctypes.c_int,
n: ctypes.c_size_t,
k: ctypes.c_int,
hist, # type: ctypes.Array[ctypes.c_int64] # type: ignore
scale_search: ctypes.c_bool,
) -> int:
return _lib.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n, k, hist, scale_search)
_lib.ggml_quantize_tensor_rtn.argtypes = [
ctypes.POINTER(ctypes.c_float),
ctypes.c_void_p,
ctypes.POINTER(ctypes.c_float),
ctypes.c_int,
ctypes.c_size_t,
ctypes.c_int,
ctypes.POINTER(ctypes.c_int64),
ctypes.c_bool,
]
_lib.ggml_quantize_tensor_rtn.restype = ctypes.c_size_t
def ggml_type_size(qtype: ctypes.c_int) -> int:
return _lib.ggml_type_size(qtype)

View file

@ -50,6 +50,8 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
"q5_k": 28,
"fp6": 29,
"fp6_k": 30,
"sym_int4_rtn": 31,
"sym_int8_rtn": 32,
}
# mixed precison from llama.cpp

View file

@ -81,6 +81,7 @@ Q4_K = ggml_tensor_qtype["q4_k"]
Q6_K = ggml_tensor_qtype["q6_k"]
Q5_K = ggml_tensor_qtype["q5_k"]
FP6_K = ggml_tensor_qtype["fp6_k"]
SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
# For sym_int4
@ -216,14 +217,27 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
f"Last dim of input tensor must be multiple of {QK}")
dst_size = (n // QK) * block_size_in_bytes
dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
device=device)
if qtype in [SYM_INT8_RTN]:
dst_tensor = torch.empty(dst_size, dtype=torch.int8,
device=device)
scale = torch.empty(n // k, dtype=torch.float32,
device=device)
else:
dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
device=device)
if not convert_shape_only and device != 'meta':
dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
hist = (ctypes.c_int64 * 16)()
if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
if qtype in [SYM_INT8_RTN]:
scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
k, hist, enable_scale_search)
dst_tensor = dst_tensor.reshape_as(tensor)
return dst_tensor, scale.type(torch.float16)
else:
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
else:
if imatrix is not None:
# quantize with importance matrix

View file

@ -77,7 +77,7 @@ class _BaseAutoModelClass:
from intel_npu_acceleration_library.dtypes import int8, int4
qtype_map = {
'sym_int4': int4,
'sym_int8': int8,
'sym_int8': "sym_int8_rtn",
'fp16': torch.half,
'fp32': torch.float,
}
@ -119,9 +119,12 @@ class _BaseAutoModelClass:
from intel_npu_acceleration_library.compiler import create_npu_kernels
with torch.no_grad():
optimize_llm(model)
if not qtype.is_floating_point:
model = quantize_model(model, qtype)
create_npu_kernels(model)
if qtype == "sym_int8_rtn":
cls.load_convert(qtype, model, *args, **kwargs)
else:
if not qtype.is_floating_point:
model = quantize_model(model, qtype)
create_npu_kernels(model)
model = model.eval()
except ImportError as _e:
# for intel_npu_acceleration_library < 1.1.0
@ -133,6 +136,11 @@ class _BaseAutoModelClass:
return model
@classmethod
def load_convert(cls, q_k, optimize_model, *arg, **kwarg):
from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear
replace_with_QuantizedLinear(optimize_model, q_k)
@staticmethod
def save_low_bit(self, model_dir: str, *args, **kwargs):
os.makedirs(model_dir, exist_ok=True)

View file

@ -15,6 +15,49 @@
import torch
from intel_npu_acceleration_library.nn import QuantizedLinear
def module_optimization(func) -> torch.nn.Module:
"""Optimize recursively a torch.nn.Module with a specific function.
The function `func` get called recursively to every module in the network.
Args:
func (Callable): optimization function
Returns:
torch.nn.Module: optimized module
"""
def wrapper(model: torch.nn.Module, qtype, *args, **kwargs):
"""Recursively apply the optimization function.
Args:
model (torch.nn.Module): original module
args (Any): positional arguments
kwargs (Any): keyword arguments
"""
for name, layer in model.named_children():
new_layer = func(layer, qtype, *args, **kwargs)
if new_layer:
model.add_module(name, new_layer)
wrapper(new_layer, qtype, *args, **kwargs)
else:
wrapper(layer, qtype, *args, **kwargs)
return wrapper
@module_optimization
def replace_with_QuantizedLinear(layer, qtype):
from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
from ipex_llm.ggml.quantize import ggml_tensor_qtype
iqtype = ggml_tensor_qtype[qtype]
if isinstance(layer, torch.nn.Linear):
qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, 'cpu')
return QuantizedLinear(qweights, scale, layer.bias)
def convert_forward(m, target_m, new_forward):