Support fp8 e5m2 on arc (#9711)

* init

* fix style

* update

* fix style

* update
This commit is contained in:
Yina Chen 2023-12-20 16:26:17 +08:00 committed by GitHub
parent e54c428d30
commit cd652a1710
3 changed files with 17 additions and 14 deletions

View file

@ -33,10 +33,12 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
"nf4": 10, "nf4": 10,
"nf3": 11, "nf3": 11,
"fp16": 12, "fp16": 12,
"fp8": 15, "fp8_e4m3": 15, # fp8 in e4m3 format
"fp4": 16, "fp4": 16,
"mixed_fp4": 17, # Mixture of Formats Quantization 4 bits "mixed_fp4": 17, # Mixture of Formats Quantization 4 bits
"mixed_fp8": 18} # Mixture of Formats Quantization 8 bits "mixed_fp8": 18, # Mixture of Formats Quantization 8 bits
"fp8_e5m2": 19, # fp8 in e5m2 format
"fp8": 15} # fp8 in e4m3 format
_llama_quantize_type = {"q4_0": 2, _llama_quantize_type = {"q4_0": 2,
"q4_1": 3, "q4_1": 3,

View file

@ -66,10 +66,11 @@ ASYM_INT4 = ggml_tensor_qtype["asym_int4"]
SYM_INT8 = ggml_tensor_qtype["sym_int8"] SYM_INT8 = ggml_tensor_qtype["sym_int8"]
NF4 = ggml_tensor_qtype["nf4"] NF4 = ggml_tensor_qtype["nf4"]
NF3 = ggml_tensor_qtype["nf3"] NF3 = ggml_tensor_qtype["nf3"]
FP8 = ggml_tensor_qtype["fp8"] FP8E4 = ggml_tensor_qtype["fp8_e4m3"]
FP4 = ggml_tensor_qtype["fp4"] FP4 = ggml_tensor_qtype["fp4"]
MOFQ4 = ggml_tensor_qtype["mixed_fp4"] MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
MOFQ8 = ggml_tensor_qtype["mixed_fp8"] MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
def get_block_size(qtype: str): def get_block_size(qtype: str):
@ -117,7 +118,7 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
src = ctypes.c_void_p(tensor.data.data_ptr()) src = ctypes.c_void_p(tensor.data.data_ptr())
if qtype in [SYM_INT4, ASYM_INT4, SYM_INT8, NF4, NF3, FP4, FP8]: if qtype in [SYM_INT4, ASYM_INT4, SYM_INT8, NF4, NF3, FP4, FP8E4, FP8E5]:
dst_tensor = torch.empty_like(tensor) dst_tensor = torch.empty_like(tensor)
elif qtype == ggml_tensor_qtype["sym_int5"]: elif qtype == ggml_tensor_qtype["sym_int5"]:
QK = ggml.ggml_qk_size(qtype) QK = ggml.ggml_qk_size(qtype)
@ -142,7 +143,7 @@ def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int
src = ctypes.c_void_p(tensor.data.data_ptr()) src = ctypes.c_void_p(tensor.data.data_ptr())
if qtype in [SYM_INT4, ASYM_INT4, SYM_INT8, NF4, NF3, FP4, FP8]: if qtype in [SYM_INT4, ASYM_INT4, SYM_INT8, NF4, NF3, FP4, FP8E4, FP8E5]:
dst_tensor = torch.empty_like(tensor) dst_tensor = torch.empty_like(tensor)
elif qtype == ggml_tensor_qtype["sym_int5"]: elif qtype == ggml_tensor_qtype["sym_int5"]:
QK = ggml.ggml_qk_size(ggml_tensor_qtype["asym_int5"]) QK = ggml.ggml_qk_size(ggml_tensor_qtype["asym_int5"])
@ -245,12 +246,12 @@ class FP4Params(torch.nn.Parameter):
self.qtype = SYM_INT8 self.qtype = SYM_INT8
else: else:
q8_0_mse, w_quant_q8_0 = self.ggml_mse(w, SYM_INT8, device=device) q8_0_mse, w_quant_q8_0 = self.ggml_mse(w, SYM_INT8, device=device)
fp8_mse, w_quant_fp8 = self.ggml_mse(w, FP8, device=device) fp8_mse, w_quant_fp8 = self.ggml_mse(w, FP8E4, device=device)
if q8_0_mse <= fp8_mse: if q8_0_mse <= fp8_mse:
self.qtype = SYM_INT8 self.qtype = SYM_INT8
self.data = w_quant_q8_0 self.data = w_quant_q8_0
else: else:
self.qtype = FP8 self.qtype = FP8E4
self.data = w_quant_fp8 self.data = w_quant_fp8
else: else:
w_quantized = ggml_convert_qtype(w, self.qtype, w_quantized = ggml_convert_qtype(w, self.qtype,
@ -510,8 +511,8 @@ class LowBitLinear(nn.Linear):
else: else:
# CPU logic # CPU logic
# todo may need to set a different number on different platforms # todo may need to set a different number on different platforms
invalidInputError(self.qtype != NF3 and self.qtype != NF4 and self.qtype != FP8 invalidInputError(self.qtype != NF3 and self.qtype != NF4 and self.qtype != FP8E4
and self.qtype != FP4, and self.qtype != FP4 and self.qtype != FP8E5,
"NF3, NF4, FP4 and FP8 quantization are currently not" "NF3, NF4, FP4 and FP8 quantization are currently not"
" supported on CPU") " supported on CPU")
if self.training and x.requires_grad: if self.training and x.requires_grad:

View file

@ -91,10 +91,10 @@ class _BaseAutoModelClass:
if the model is GPTQ model. if the model is GPTQ model.
Default to be False. Default to be False.
:param load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5 :param load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5
, sym_int8, nf3, nf4, fp4, fp8 or fp16. sym_int4 means symmetric , sym_int8, nf3, nf4, fp4, fp8, fp8_e4m3, fp8_e5m2 or fp16.
int 4, asym_int4 means asymmetric int 4, nf4 means 4-bit sym_int4 means symmetric int 4, asym_int4 means asymmetric int 4,
NormalFloat, etc. Relevant low bit optimizations will be applied nf4 means 4-bit NormalFloat, etc. Relevant low bit optimizations
to the model. will be applied to the model.
:param optimize_model: boolean value, Whether to further optimize the low_bit llm model. :param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
Default to be True. Default to be True.
:param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
@ -216,7 +216,7 @@ class _BaseAutoModelClass:
invalidInputError(q_k in ggml_tensor_qtype, invalidInputError(q_k in ggml_tensor_qtype,
f"Unknown load_in_low_bit value: {q_k}, expected:" f"Unknown load_in_low_bit value: {q_k}, expected:"
f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, " f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
"fp4, fp8, fp16, mixed_fp4 or mixed_fp8.") "fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, mixed_fp4 or mixed_fp8.")
qtype = ggml_tensor_qtype[q_k] qtype = ggml_tensor_qtype[q_k]
# In case it needs a second try, # In case it needs a second try,