Init NPU quantize method and support q8_0_rtn (#11452)
* q8_0_rtn
* fix float point
parent 319a3b36b2
commit cf8eb7b128
5 changed files with 101 additions and 7 deletions
@@ -991,6 +991,33 @@ _lib.ggml_quantize_tensor_with_weights.argtypes = [
 _lib.ggml_quantize_tensor_with_weights.restype = ctypes.c_size_t
 
 
+# GGML API
+def ggml_quantize_tensor_rtn(
+    src,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+    dst: ctypes.c_void_p,
+    scale_ptr,  # type: ctypes.Array[ctypes.c_float] # type: ignore
+    qtype: ctypes.c_int,
+    n: ctypes.c_size_t,
+    k: ctypes.c_int,
+    hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
+    scale_search: ctypes.c_bool,
+) -> int:
+    return _lib.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n, k, hist, scale_search)
+
+
+_lib.ggml_quantize_tensor_rtn.argtypes = [
+    ctypes.POINTER(ctypes.c_float),
+    ctypes.c_void_p,
+    ctypes.POINTER(ctypes.c_float),
+    ctypes.c_int,
+    ctypes.c_size_t,
+    ctypes.c_int,
+    ctypes.POINTER(ctypes.c_int64),
+    ctypes.c_bool,
+]
+_lib.ggml_quantize_tensor_rtn.restype = ctypes.c_size_t
+
+
 def ggml_type_size(qtype: ctypes.c_int) -> int:
     return _lib.ggml_type_size(qtype)
 
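For context, here is a minimal sketch of how the new ctypes entry point can be driven from Python, based purely on the argtypes declared above. The import path for the bindings and the buffer sizing (one int8 per weight plus one fp32 scale per row of `k` elements) are assumptions for illustration; the real call site is `ggml_convert_qtype` in the hunks below, which sizes the destination via the ggml block-size helpers.

import ctypes
import torch
import ipex_llm.ggml.model.llama.llama_cpp as ggml   # assumed module path for the bindings above
from ipex_llm.ggml.quantize import ggml_tensor_qtype

k = 4096                 # elements per row
n = 2 * k                # total number of elements (two rows)
weight = torch.randn(n, dtype=torch.float32)

src = ctypes.cast(weight.data_ptr(), ctypes.POINTER(ctypes.c_float))
dst_tensor = torch.empty(n, dtype=torch.int8)        # assumed: one int8 per weight for sym_int8_rtn
dst = ctypes.c_void_p(dst_tensor.data_ptr())
scale = torch.empty(n // k, dtype=torch.float32)      # one RTN scale per row
scale_ptr = ctypes.cast(scale.data_ptr(), ctypes.POINTER(ctypes.c_float))
hist = (ctypes.c_int64 * 16)()

qtype = ggml_tensor_qtype["sym_int8_rtn"]             # 32, added by this commit
ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n, k, hist, False)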
@@ -50,6 +50,8 @@ ggml_tensor_qtype = {"sym_int4": 2,  # q4_0 in ggml
                      "q5_k": 28,
                      "fp6": 29,
                      "fp6_k": 30,
+                     "sym_int4_rtn": 31,
+                     "sym_int8_rtn": 32,
                      }
 
 # mixed precison from llama.cpp
@@ -81,6 +81,7 @@ Q4_K = ggml_tensor_qtype["q4_k"]
 Q6_K = ggml_tensor_qtype["q6_k"]
 Q5_K = ggml_tensor_qtype["q5_k"]
 FP6_K = ggml_tensor_qtype["fp6_k"]
+SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
 
 
 # For sym_int4
@@ -216,14 +217,27 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                           f"Last dim of input tensor must be multiple of {QK}")
 
     dst_size = (n // QK) * block_size_in_bytes
-    dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
-                             device=device)
+    if qtype in [SYM_INT8_RTN]:
+        dst_tensor = torch.empty(dst_size, dtype=torch.int8,
+                                 device=device)
+        scale = torch.empty(n // k, dtype=torch.float32,
+                            device=device)
+    else:
+        dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
+                                 device=device)
 
     if not convert_shape_only and device != 'meta':
         dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
         hist = (ctypes.c_int64 * 16)()
         if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
-            ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
+            if qtype in [SYM_INT8_RTN]:
+                scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
+                ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
+                                              k, hist, enable_scale_search)
+                dst_tensor = dst_tensor.reshape_as(tensor)
+                return dst_tensor, scale.type(torch.float16)
+            else:
+                ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
         else:
             if imatrix is not None:
                 # quantize with importance matrix
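A short sketch of how the changed `ggml_convert_qtype` is called for the new qtype and what it now returns; the weight shape is illustrative (the last dim must be a multiple of the block size, per the check above), and the call mirrors the one made in `replace_with_QuantizedLinear` further down.

import torch
from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
from ipex_llm.ggml.quantize import ggml_tensor_qtype

weight = torch.randn(4096, 4096, dtype=torch.float32)

# For SYM_INT8_RTN the function now returns the int8 weights plus a per-row fp16 scale;
# for every other qtype it keeps returning only the packed uint8 tensor.
qweights, scale = ggml_convert_qtype(weight, ggml_tensor_qtype["sym_int8_rtn"], 'cpu')
print(qweights.dtype, qweights.shape)   # torch.int8, same shape as the input weight
print(scale.dtype, scale.shape)         # torch.float16, n // k values (one per row)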
@@ -77,7 +77,7 @@ class _BaseAutoModelClass:
             from intel_npu_acceleration_library.dtypes import int8, int4
             qtype_map = {
                 'sym_int4': int4,
-                'sym_int8': int8,
+                'sym_int8': "sym_int8_rtn",
                 'fp16': torch.half,
                 'fp32': torch.float,
             }
@@ -119,9 +119,12 @@ class _BaseAutoModelClass:
             from intel_npu_acceleration_library.compiler import create_npu_kernels
             with torch.no_grad():
                 optimize_llm(model)
-                if not qtype.is_floating_point:
-                    model = quantize_model(model, qtype)
-                create_npu_kernels(model)
+                if qtype == "sym_int8_rtn":
+                    cls.load_convert(qtype, model, *args, **kwargs)
+                else:
+                    if not qtype.is_floating_point:
+                        model = quantize_model(model, qtype)
+                    create_npu_kernels(model)
             model = model.eval()
         except ImportError as _e:
             # for intel_npu_acceleration_library < 1.1.0
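From the user side, an NPU `from_pretrained` call that asks for sym_int8 now takes the `cls.load_convert()` path (RTN quantization through ggml) instead of intel_npu_acceleration_library's `quantize_model()`. A hedged sketch of that flow follows; the `AutoModelForCausalLM` import, the `load_in_low_bit` argument name, and the model id are taken from the usual ipex-llm NPU examples, not from this diff.

from ipex_llm.transformers.npu_model import AutoModelForCausalLM

# 'sym_int8' is mapped to "sym_int8_rtn" by the qtype_map above, so the weights are
# quantized by ggml_quantize_tensor_rtn and Linear layers become QuantizedLinear.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # illustrative model id
    load_in_low_bit="sym_int8",
    trust_remote_code=True,
)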
@@ -133,6 +136,11 @@ class _BaseAutoModelClass:
 
         return model
 
+    @classmethod
+    def load_convert(cls, q_k, optimize_model, *arg, **kwarg):
+        from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear
+        replace_with_QuantizedLinear(optimize_model, q_k)
+
     @staticmethod
     def save_low_bit(self, model_dir: str, *args, **kwargs):
         os.makedirs(model_dir, exist_ok=True)
@@ -15,6 +15,49 @@
 
 
 import torch
+from intel_npu_acceleration_library.nn import QuantizedLinear
+
+
+def module_optimization(func) -> torch.nn.Module:
+    """Optimize recursively a torch.nn.Module with a specific function.
+
+    The function `func` get called recursively to every module in the network.
+
+    Args:
+        func (Callable): optimization function
+
+    Returns:
+        torch.nn.Module: optimized module
+    """
+
+    def wrapper(model: torch.nn.Module, qtype, *args, **kwargs):
+        """Recursively apply the optimization function.
+
+        Args:
+            model (torch.nn.Module): original module
+            args (Any): positional arguments
+            kwargs (Any): keyword arguments
+
+        """
+        for name, layer in model.named_children():
+            new_layer = func(layer, qtype, *args, **kwargs)
+            if new_layer:
+                model.add_module(name, new_layer)
+                wrapper(new_layer, qtype, *args, **kwargs)
+            else:
+                wrapper(layer, qtype, *args, **kwargs)
+
+    return wrapper
+
+
+@module_optimization
+def replace_with_QuantizedLinear(layer, qtype):
+    from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
+    from ipex_llm.ggml.quantize import ggml_tensor_qtype
+    iqtype = ggml_tensor_qtype[qtype]
+    if isinstance(layer, torch.nn.Linear):
+        qweights, scale = ggml_convert_qtype(layer.weight.data, iqtype, 'cpu')
+        return QuantizedLinear(qweights, scale, layer.bias)
+
+
 def convert_forward(m, target_m, new_forward):
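Finally, a small standalone sketch of the new helper applied to a toy two-layer model rather than a real checkpoint. Note that the decorated function mutates the module tree in place (the wrapper has no return value), which is why `load_convert` above does not reassign the model; the layer sizes here are illustrative.

import torch
from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear

toy = torch.nn.Sequential(
    torch.nn.Linear(4096, 11008, bias=False),
    torch.nn.Linear(11008, 4096, bias=False),
)

# Every torch.nn.Linear child is quantized to sym_int8_rtn on CPU and swapped for an
# intel_npu_acceleration_library QuantizedLinear holding the int8 weights and fp16 scales.
replace_with_QuantizedLinear(toy, "sym_int8_rtn")
print(toy)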