[LLM] Support mixed_fp8 on Arc (#9415)

* UT GPU memory allocation fix

* support mixed_8bit on Arc

* rename mixed_4bit to mixed_fp4 and mixed_8bit to mixed_fp8

* revert unexpected changes

* revert unexpected changes

* unify common logic

* rename in llm xmx_checker

* fix typo and re-unify
SONG Ge 2023-11-13 09:26:30 +08:00 committed by GitHub
parent ac7fbe77e2
commit 2888818b3a
4 changed files with 36 additions and 17 deletions


@@ -35,7 +35,8 @@ ggml_tensor_qtype = {"sym_int4": 2,  # q4_0 in ggml
                      "fp16": 12,
                      "fp8": 15,
                      "fp4": 16,
-                     "mixed_4bit": 17}  # Mixture of Formats Quantization 4 bits
+                     "mixed_fp4": 17,   # Mixture of Formats Quantization 4 bits
+                     "mixed_fp8": 18}   # Mixture of Formats Quantization 8 bits
 
 _llama_quantize_type = {"q4_0": 2,
                         "q4_1": 3,


@@ -66,7 +66,8 @@ NF4 = ggml_tensor_qtype["nf4"]
 NF3 = ggml_tensor_qtype["nf3"]
 FP8 = ggml_tensor_qtype["fp8"]
 FP4 = ggml_tensor_qtype["fp4"]
-MOFQ4 = ggml_tensor_qtype["mixed_4bit"]
+MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
+MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
 
 
 def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
@@ -194,6 +195,16 @@ class FP4Params(torch.nn.Parameter):
         self.convert_shape_only = convert_shape_only
         return self
 
+    def ggml_mse(self, w, ggml_qtype, device):
+        from torch.nn.functional import mse_loss
+        w_quant = ggml_convert_qtype(w, ggml_qtype,
+                                     device=device,
+                                     convert_shape_only=self.convert_shape_only)
+        w_dequant = ggml_convert_fp32(w_quant, w.shape,
+                                      reduce(mul, w.shape, 1), ggml_qtype)
+        mse = mse_loss(w_dequant, w)
+        return mse, w_quant
+
     def quantize(self, device=None):
         if not self.quantized:
             w = self.data.contiguous().float()
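
The new ggml_mse helper scores a candidate format by round-tripping the weight through quantization and dequantization, then measuring the reconstruction error against the original fp32 tensor. A minimal pure-PyTorch analogue of that idea, assuming per-tensor symmetric scaling (the real ggml q8_0/fp8 formats quantize in small blocks with per-block scales):

import torch
from torch.nn.functional import mse_loss

def roundtrip_mse_int8(w: torch.Tensor) -> torch.Tensor:
    # Fake-quantize to symmetric int8, dequantize, and measure the
    # reconstruction error -- the same round-trip idea ggml_mse uses
    # to compare candidate ggml formats.
    scale = w.abs().max().clamp(min=1e-8) / 127.0
    w_dq = torch.round(w / scale).clamp(-127, 127) * scale
    return mse_loss(w_dq, w)
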
@@ -206,25 +217,31 @@ class FP4Params(torch.nn.Parameter):
                     # save/load
                     self.qtype = SYM_INT4
                 else:
-                    from torch.nn.functional import mse_loss
-                    w_quant_q4_0 = ggml_convert_qtype(w, SYM_INT4,
-                                                      device=device,
-                                                      convert_shape_only=self.convert_shape_only)
-                    w_q4_0_dequant = ggml_convert_fp32(w_quant_q4_0, w.shape,
-                                                       reduce(mul, w.shape, 1), SYM_INT4)
-                    w_quant_fp4 = ggml_convert_qtype(w, FP4,
-                                                     device=device,
-                                                     convert_shape_only=self.convert_shape_only)
-                    w_fp4_dequant = ggml_convert_fp32(w_quant_fp4, w.shape,
-                                                      reduce(mul, w.shape, 1), FP4)
-                    q4_0_mse = mse_loss(w_q4_0_dequant, w)
-                    fp4_mse = mse_loss(w_fp4_dequant, w)
+                    q4_0_mse, w_quant_q4_0 = self.ggml_mse(w, SYM_INT4, device=device)
+                    fp4_mse, w_quant_fp4 = self.ggml_mse(w, FP4, device=device)
                     if q4_0_mse <= fp4_mse:
                         self.qtype = SYM_INT4
                         self.data = w_quant_q4_0
                     else:
                         self.qtype = FP4
                         self.data = w_quant_fp4
+            elif self.qtype == MOFQ8:
+                if device == 'meta':
+                    w_quantized = ggml_convert_qtype(w, SYM_INT8,
+                                                     device=device,
+                                                     convert_shape_only=self.convert_shape_only)
+                    # TODO: should load from config, the current implementation doesn't support
+                    # save/load
+                    self.qtype = SYM_INT8
+                else:
+                    q8_0_mse, w_quant_q8_0 = self.ggml_mse(w, SYM_INT8, device=device)
+                    fp8_mse, w_quant_fp8 = self.ggml_mse(w, FP8, device=device)
+                    if q8_0_mse <= fp8_mse:
+                        self.qtype = SYM_INT8
+                        self.data = w_quant_q8_0
+                    else:
+                        self.qtype = FP8
+                        self.data = w_quant_fp8
             else:
                 w_quantized = ggml_convert_qtype(w, self.qtype,
                                                  device=device,
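
With ggml_mse factored out, the MOFQ4 and MOFQ8 branches differ only in their candidate pair. A hypothetical further refactor (not part of this commit; pick_format is an invented name) could fold both into one loop, where listing SYM_INT4/SYM_INT8 first preserves the tie-breaking of the <= comparisons above:

def pick_format(self, w, candidates, device):
    # Score each candidate format's round-trip MSE and keep the best;
    # a strict '<' keeps the first-listed candidate on ties, matching
    # the original '<=' preference for SYM_INT4/SYM_INT8.
    best = None
    for qtype in candidates:
        mse, w_quant = self.ggml_mse(w, qtype, device=device)
        if best is None or mse < best[0]:
            best = (mse, qtype, w_quant)
    _, self.qtype, self.data = best

# Usage inside quantize(): self.pick_format(w, [SYM_INT4, FP4], device)
# for mixed_fp4, and self.pick_format(w, [SYM_INT8, FP8], device) for mixed_fp8.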


@@ -113,7 +113,7 @@ class _BaseAutoModelClass:
             invalidInputError(q_k in ggml_tensor_qtype,
                               f"Unknown load_in_low_bit value: {q_k}, expected:"
                               f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
-                              "fp4, fp8, fp16 or mixed_4bit.")
+                              "fp4, fp8, fp16, mixed_fp4 or mixed_fp8.")
             qtype = ggml_tensor_qtype[q_k]
             # In case it needs a second try,
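
With the validation above accepting the new strings, mixed_fp8 can be requested directly at load time. A minimal usage sketch for an Arc GPU, assuming bigdl-llm's transformers-style entry point and a Llama-2 checkpoint as a stand-in model:

import intel_extension_for_pytorch as ipex  # enables the 'xpu' device
from bigdl.llm.transformers import AutoModelForCausalLM

# Each weight tensor is quantized to whichever of q8_0/fp8 yields the
# lower round-trip MSE, then the model is moved to the Intel GPU.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_low_bit="mixed_fp8")
model = model.to('xpu')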


@@ -25,7 +25,8 @@ NF4 = ggml_tensor_qtype["nf4"]
 NF3 = ggml_tensor_qtype["nf3"]
 FP8 = ggml_tensor_qtype["fp8"]
 FP4 = ggml_tensor_qtype["fp4"]
-MOFQ4 = ggml_tensor_qtype["mixed_4bit"]
+MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
+MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
 
 
 class XMXChecker: