Support MoFQ4 on Arc (#9301)

* init

* update

* fix style

* fix style

* fix style

* address review comments
Yina Chen 2023-11-01 10:59:46 +08:00 committed by GitHub
parent 8ef8e25178
commit 2262ae4d13
4 changed files with 68 additions and 6 deletions


@@ -999,6 +999,24 @@ _lib.ggml_dequantize_q4_0.argtypes = [
_lib.ggml_quantize_q4_0.restype = None


def ggml_dequantize(
    src: ctypes.c_void_p,
    dst: ctypes.c_void_p,
    k: ctypes.c_size_t,
    qtype: ctypes.c_int
):
    _lib.ggml_dequantize(src, dst, k, qtype)


_lib.ggml_dequantize.argtypes = [
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_size_t,
    ctypes.c_int
]
_lib.ggml_dequantize.restype = None


def ggml_q_format_convet_cpu2xpu(
    src: ctypes.c_void_p,
    dst: ctypes.c_void_p,
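For reference, a minimal usage sketch of the new ggml_dequantize binding (not part of this commit). The helper name dequantize_to_fp32 is illustrative, and it assumes quantized is a uint8 torch tensor produced by the matching ggml quantize routine and encoding k fp32 elements:

import ctypes

import torch

def dequantize_to_fp32(quantized, k, qtype):
    # Allocate the fp32 destination, then hand raw pointers to the native routine.
    dst = torch.empty(k, dtype=torch.float)
    src_ptr = ctypes.c_void_p(quantized.data_ptr())
    dst_ptr = ctypes.c_void_p(dst.data_ptr())
    ggml_dequantize(src_ptr, dst_ptr, k, qtype)  # the wrapper added above
    return dst

The higher-level ggml_convert_fp32 helper added in the next file drives the binding in the same way.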


@@ -34,7 +34,8 @@ ggml_tensor_qtype = {"sym_int4": 2,  # q4_0 in ggml
                     "nf3": 11,
                     "fp16": 12,
                     "fp8": 15,
                     "fp4": 16,
                     "mixed_4bit": 17}  # Mixture of Formats Quantization 4 bits

_llama_quantize_type = {"q4_0": 2,
                        "q4_1": 3,


@@ -66,6 +66,7 @@ NF4 = ggml_tensor_qtype["nf4"]
NF3 = ggml_tensor_qtype["nf3"]
FP8 = ggml_tensor_qtype["fp8"]
FP4 = ggml_tensor_qtype["fp4"]
MOFQ4 = ggml_tensor_qtype["mixed_4bit"]


def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
@@ -158,6 +159,19 @@ def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
    return dst_tensor


def ggml_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int, qtype: int):
    invalidInputError(tensor.dtype == torch.uint8,
                      "Input tensor must be uint8")
    src_ptr = ctypes.c_void_p(tensor.data.data_ptr())
    dst_size = k
    dst_tensor = torch.empty(weight_shape, dtype=torch.float)
    dst_ptr = ctypes.c_void_p(dst_tensor.data.data_ptr())
    ggml.ggml_dequantize(src_ptr, dst_ptr, k, qtype)
    return dst_tensor


# Rename to FP4Params to trigger initializing
# the params layer with all parameters on the CPU
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
@@ -183,10 +197,39 @@ class FP4Params(torch.nn.Parameter):
    def quantize(self, device=None):
        if not self.quantized:
            w = self.data.contiguous().float()
            if self.qtype == MOFQ4:
                if device == 'meta':
                    w_quantized = ggml_convert_qtype(w, SYM_INT4,
                                                     device=device,
                                                     convert_shape_only=self.convert_shape_only)
                    # TODO: should load from config, the current implementation doesn't support
                    # save/load
                    self.qtype = SYM_INT4
                else:
                    from torch.nn.functional import mse_loss
                    w_quant_q4_0 = ggml_convert_qtype(w, SYM_INT4,
                                                      device=device,
                                                      convert_shape_only=self.convert_shape_only)
                    w_q4_0_dequant = ggml_convert_fp32(w_quant_q4_0, w.shape,
                                                       reduce(mul, w.shape, 1), SYM_INT4)
                    w_quant_fp4 = ggml_convert_qtype(w, FP4,
                                                     device=device,
                                                     convert_shape_only=self.convert_shape_only)
                    w_fp4_dequant = ggml_convert_fp32(w_quant_fp4, w.shape,
                                                      reduce(mul, w.shape, 1), FP4)
                    q4_0_mse = mse_loss(w_q4_0_dequant, w)
                    fp4_mse = mse_loss(w_fp4_dequant, w)
                    if q4_0_mse <= fp4_mse:
                        self.qtype = SYM_INT4
                        self.data = w_quant_q4_0
                    else:
                        self.qtype = FP4
                        self.data = w_quant_fp4
            else:
                w_quantized = ggml_convert_qtype(w, self.qtype,
                                                 device=device,
                                                 convert_shape_only=self.convert_shape_only)
                self.data = w_quantized
            self.quantized = True
            self._shape = w.shape
        return self
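The selection logic above amounts to a per-tensor choice between q4_0 (sym_int4) and fp4: quantize the weight both ways, dequantize each result, and keep the format whose reconstruction has the lower MSE (sym_int4 wins ties). A generalized sketch of that rule, assuming quantize_fn and dequantize_fn behave like ggml_convert_qtype and ggml_convert_fp32; pick_4bit_format and its signature are illustrative, not part of the change:

from functools import reduce
from operator import mul

from torch.nn.functional import mse_loss

def pick_4bit_format(w, candidate_qtypes, quantize_fn, dequantize_fn):
    # Try each candidate 4-bit format and keep the one with the smallest
    # round-trip (quantize -> dequantize) reconstruction error against w.
    k = reduce(mul, w.shape, 1)
    best_err, best_qtype, best_data = None, None, None
    for qtype in candidate_qtypes:
        w_q = quantize_fn(w, qtype)
        w_dq = dequantize_fn(w_q, w.shape, k, qtype)
        err = mse_loss(w_dq, w)
        if best_err is None or err < best_err:
            best_err, best_qtype, best_data = err, qtype, w_q
    return best_qtype, best_data

On the 'meta' device there is no real data to compare, so the code above falls back to sym_int4; as the TODO notes, the per-tensor choice is not yet persisted for save/load.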


@@ -111,7 +111,7 @@ class _BaseAutoModelClass:
            invalidInputError(q_k in ggml_tensor_qtype,
                              f"Unknown load_in_low_bit value: {q_k}, expected:"
                              f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
                              "fp4, fp8, fp16 or mixed_4bit.")
            qtype = ggml_tensor_qtype[q_k]

            # In case it needs a second try,
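From the user side, MoFQ4 is requested through the same load_in_low_bit argument validated above. A hypothetical usage sketch, assuming the public AutoModelForCausalLM wrapper built on _BaseAutoModelClass; the import path, model id, and device handling are illustrative and not shown in this diff:

from bigdl.llm.transformers import AutoModelForCausalLM

# "mixed_4bit" resolves to ggml_tensor_qtype["mixed_4bit"] (17); each linear weight
# is then stored as either sym_int4 or fp4, whichever reconstructs it better.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_low_bit="mixed_4bit")
model = model.to("xpu")  # move to the XPU device to run on an Intel Arc GPU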