Support MoFQ4 on arc (#9301)
* init
* update
* fix style
* fix style
* fix style
* meet comments
This commit is contained in:
parent 8ef8e25178
commit 2262ae4d13

4 changed files with 68 additions and 6 deletions
@@ -999,6 +999,24 @@ _lib.ggml_dequantize_q4_0.argtypes = [
 _lib.ggml_quantize_q4_0.restype = None
 
 
+def ggml_dequantize(
+    src: ctypes.c_void_p,
+    dst: ctypes.c_void_p,
+    k: ctypes.c_size_t,
+    qtype: ctypes.c_int
+):
+    _lib.ggml_dequantize(src, dst, k, qtype)
+
+
+_lib.ggml_dequantize.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_size_t,
+    ctypes.c_int
+]
+_lib.ggml_dequantize.restype = None
+
+
 def ggml_q_format_convet_cpu2xpu(
     src: ctypes.c_void_p,
     dst: ctypes.c_void_p,
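The hunk above follows the usual ctypes recipe: declare argtypes/restype for the shared-library symbol, then hand raw torch data pointers across the boundary. Below is a minimal, self-contained sketch of that same calling pattern using libc's memcpy instead of the ggml library, so it runs without the BigDL native binary; only the pattern, not the function, comes from this commit.

import ctypes
import ctypes.util

import torch

# Stand-in for the ggml shared library; memcpy has a well-known C signature.
libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.memcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
libc.memcpy.restype = ctypes.c_void_p

src = torch.arange(8, dtype=torch.float)
dst = torch.empty(8, dtype=torch.float)

# Same pointer-passing style as ggml_dequantize(src_ptr, dst_ptr, k, qtype) above.
libc.memcpy(ctypes.c_void_p(dst.data.data_ptr()),
            ctypes.c_void_p(src.data.data_ptr()),
            src.numel() * src.element_size())
assert torch.equal(src, dst)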
@@ -34,7 +34,8 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
                      "nf3": 11,
                      "fp16": 12,
                      "fp8": 15,
-                     "fp4": 16}
+                     "fp4": 16,
+                     "mixed_4bit": 17}  # Mixture of Formats Quantization 4 bits
 
 _llama_quantize_type = {"q4_0": 2,
                         "q4_1": 3,
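For orientation, the new "mixed_4bit" entry simply adds one more string-to-integer code that the loading path validates and passes down to the native layer; a tiny illustration follows (only the keys visible in this hunk are reproduced, the lookup itself is illustrative):

# Subset of ggml_tensor_qtype reproduced from the hunk above (other keys omitted).
ggml_tensor_qtype = {"nf3": 11,
                     "fp16": 12,
                     "fp8": 15,
                     "fp4": 16,
                     "mixed_4bit": 17}  # Mixture of Formats Quantization 4 bits

q_k = "mixed_4bit"
assert q_k in ggml_tensor_qtype, f"Unknown load_in_low_bit value: {q_k}"
qtype = ggml_tensor_qtype[q_k]  # 17, later aliased as MOFQ4 and checked in quantize()
print(qtype)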
@@ -66,6 +66,7 @@ NF4 = ggml_tensor_qtype["nf4"]
 NF3 = ggml_tensor_qtype["nf3"]
 FP8 = ggml_tensor_qtype["fp8"]
 FP4 = ggml_tensor_qtype["fp4"]
+MOFQ4 = ggml_tensor_qtype["mixed_4bit"]
 
 
 def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
@@ -158,6 +159,19 @@ def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
     return dst_tensor
 
 
+def ggml_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int, qtype: int):
+    invalidInputError(tensor.dtype == torch.uint8,
+                      "Input tensor must be uint8")
+    src_ptr = ctypes.c_void_p(tensor.data.data_ptr())
+
+    dst_size = k
+    dst_tensor = torch.empty(weight_shape, dtype=torch.float)
+    dst_ptr = ctypes.c_void_p(dst_tensor.data.data_ptr())
+
+    ggml.ggml_dequantize(src_ptr, dst_ptr, k, qtype)
+    return dst_tensor
+
+
 # Rename to FP4Params to trigger initializing
 # the params layer with all parameters on the CPU
 # https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
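ggml_convert_fp32 is the inverse of ggml_convert_qtype for inspection purposes: it dequantizes a packed uint8 weight buffer back to float32 so the two can be compared. A rough sketch of that round trip, assuming a CPU device and an import path that is not shown in this diff:

# Hypothetical round-trip error check built from the two helpers in this file.
# The module path below is an assumption; only the function names and argument
# order come from this diff.
from functools import reduce
from operator import mul

import torch
from torch.nn.functional import mse_loss

from bigdl.llm.transformers.low_bit_linear import (  # path assumed
    SYM_INT4, ggml_convert_qtype, ggml_convert_fp32)

w = torch.randn(4096, 4096)
w_q4_0 = ggml_convert_qtype(w, SYM_INT4, device='cpu')
w_dequant = ggml_convert_fp32(w_q4_0, w.shape, reduce(mul, w.shape, 1), SYM_INT4)
print("q4_0 round-trip MSE:", mse_loss(w_dequant, w).item())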
@@ -183,10 +197,39 @@ class FP4Params(torch.nn.Parameter):
     def quantize(self, device=None):
         if not self.quantized:
             w = self.data.contiguous().float()
-            w_quantized = ggml_convert_qtype(w, self.qtype,
-                                             device=device,
-                                             convert_shape_only=self.convert_shape_only)
-            self.data = w_quantized
+            if self.qtype == MOFQ4:
+                if device == 'meta':
+                    w_quantized = ggml_convert_qtype(w, SYM_INT4,
+                                                     device=device,
+                                                     convert_shape_only=self.convert_shape_only)
+                    # TODO: should load from config, the current implementation doesn't support
+                    # save/load
+                    self.qtype = SYM_INT4
+                else:
+                    from torch.nn.functional import mse_loss
+                    w_quant_q4_0 = ggml_convert_qtype(w, SYM_INT4,
+                                                      device=device,
+                                                      convert_shape_only=self.convert_shape_only)
+                    w_q4_0_dequant = ggml_convert_fp32(w_quant_q4_0, w.shape,
+                                                       reduce(mul, w.shape, 1), SYM_INT4)
+                    w_quant_fp4 = ggml_convert_qtype(w, FP4,
+                                                     device=device,
+                                                     convert_shape_only=self.convert_shape_only)
+                    w_fp4_dequant = ggml_convert_fp32(w_quant_fp4, w.shape,
+                                                      reduce(mul, w.shape, 1), FP4)
+                    q4_0_mse = mse_loss(w_q4_0_dequant, w)
+                    fp4_mse = mse_loss(w_fp4_dequant, w)
+                    if q4_0_mse <= fp4_mse:
+                        self.qtype = SYM_INT4
+                        self.data = w_quant_q4_0
+                    else:
+                        self.qtype = FP4
+                        self.data = w_quant_fp4
+            else:
+                w_quantized = ggml_convert_qtype(w, self.qtype,
+                                                 device=device,
+                                                 convert_shape_only=self.convert_shape_only)
+                self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
         return self
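The core of MoFQ4 is the per-tensor decision above: quantize the same weight with both q4_0 (sym_int4) and fp4, dequantize each, and keep the format with the smaller MSE against the original weight. The following is a self-contained sketch of that selection rule with simplified, pure-PyTorch stand-ins for the two quantizers; the real code calls the ggml kernels, and the exact fp4 level set used here is illustrative.

import torch
from torch.nn.functional import mse_loss


def fake_sym_int4(w: torch.Tensor, block: int = 32) -> torch.Tensor:
    # q4_0-style stand-in: per-block scale, values rounded to integers in [-8, 7].
    x = w.reshape(-1, block)
    scale = x.abs().amax(dim=1, keepdim=True) / 7.0
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    q = torch.clamp((x / scale).round(), -8, 7)
    return (q * scale).reshape(w.shape)


def fake_fp4(w: torch.Tensor, block: int = 32) -> torch.Tensor:
    # fp4-style stand-in: per-block scale, values snapped to a small set of float levels.
    levels = torch.tensor([0.0, 0.0625, 0.125, 0.25, 0.5, 1.0, 2.0, 3.0])
    levels = torch.cat([-levels.flip(0), levels])
    x = w.reshape(-1, block)
    scale = x.abs().amax(dim=1, keepdim=True) / levels.max()
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    idx = ((x / scale).unsqueeze(-1) - levels).abs().argmin(dim=-1)
    return (levels[idx] * scale).reshape(w.shape)


def choose_mofq4_format(w: torch.Tensor) -> str:
    # Mirror of the selection rule in FP4Params.quantize: lower round-trip MSE wins.
    q4_0_mse = mse_loss(fake_sym_int4(w), w)
    fp4_mse = mse_loss(fake_fp4(w), w)
    return "sym_int4" if q4_0_mse <= fp4_mse else "fp4"


w = torch.randn(4096, 64)
print(choose_mofq4_format(w))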
@@ -111,7 +111,7 @@ class _BaseAutoModelClass:
         invalidInputError(q_k in ggml_tensor_qtype,
                           f"Unknown load_in_low_bit value: {q_k}, expected:"
                           f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
-                          "fp4, fp8 or fp16.")
+                          "fp4, fp8, fp16 or mixed_4bit.")
         qtype = ggml_tensor_qtype[q_k]
 
         # In case it needs a second try,
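With the error message updated, "mixed_4bit" becomes a legal load_in_low_bit value. End-to-end usage would look roughly like the snippet below; the import path follows the usual bigdl-llm transformers-style API, and the model id, trust_remote_code flag, and the .to("xpu") step for Arc are illustrative rather than taken from this diff.

# Hypothetical usage sketch: load a model with MoFQ4 ("mixed_4bit") weights and move
# it to an Intel Arc GPU (XPU). Running on XPU also requires intel_extension_for_pytorch.
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",  # example model id
                                             load_in_low_bit="mixed_4bit",
                                             trust_remote_code=True)
model = model.to("xpu")  # each Linear was tagged sym_int4 or fp4 at quantize() time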