[LLM] Support mixed_fp8 on Arc (#9415)
* ut gpu allocation memory fix * support mix_8bit on arc * rename mixed_4bit to mixed_fp4 and mixed_8bit to mixed_fp8 * revert unexpected changes * revert unexpected changes * unify common logits * rename in llm xmx_checker * fix typo error and re-unify
This commit is contained in:
parent
ac7fbe77e2
commit
2888818b3a
4 changed files with 36 additions and 17 deletions
|
|
@ -35,7 +35,8 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
|
|||
"fp16": 12,
|
||||
"fp8": 15,
|
||||
"fp4": 16,
|
||||
"mixed_4bit": 17} # Mixture of Formats Quantization 4 bits
|
||||
"mixed_fp4": 17, # Mixture of Formats Quantization 4 bits
|
||||
"mixed_fp8": 18} # Mixture of Formats Quantization 8 bits
|
||||
|
||||
_llama_quantize_type = {"q4_0": 2,
|
||||
"q4_1": 3,
|
||||
|
|
|
|||
|
|
@ -66,7 +66,8 @@ NF4 = ggml_tensor_qtype["nf4"]
|
|||
NF3 = ggml_tensor_qtype["nf3"]
|
||||
FP8 = ggml_tensor_qtype["fp8"]
|
||||
FP4 = ggml_tensor_qtype["fp4"]
|
||||
MOFQ4 = ggml_tensor_qtype["mixed_4bit"]
|
||||
MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
|
||||
MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
|
||||
|
||||
|
||||
def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
|
||||
|
|
@ -194,6 +195,16 @@ class FP4Params(torch.nn.Parameter):
|
|||
self.convert_shape_only = convert_shape_only
|
||||
return self
|
||||
|
||||
def ggml_mse(self, w, ggml_qtype, device):
|
||||
from torch.nn.functional import mse_loss
|
||||
w_quant = ggml_convert_qtype(w, ggml_qtype,
|
||||
device=device,
|
||||
convert_shape_only=self.convert_shape_only)
|
||||
w_dequant = ggml_convert_fp32(w_quant, w.shape,
|
||||
reduce(mul, w.shape, 1), ggml_qtype)
|
||||
mse = mse_loss(w_dequant, w)
|
||||
return mse, w_quant
|
||||
|
||||
def quantize(self, device=None):
|
||||
if not self.quantized:
|
||||
w = self.data.contiguous().float()
|
||||
|
|
@ -206,25 +217,31 @@ class FP4Params(torch.nn.Parameter):
|
|||
# save/load
|
||||
self.qtype = SYM_INT4
|
||||
else:
|
||||
from torch.nn.functional import mse_loss
|
||||
w_quant_q4_0 = ggml_convert_qtype(w, SYM_INT4,
|
||||
device=device,
|
||||
convert_shape_only=self.convert_shape_only)
|
||||
w_q4_0_dequant = ggml_convert_fp32(w_quant_q4_0, w.shape,
|
||||
reduce(mul, w.shape, 1), SYM_INT4)
|
||||
w_quant_fp4 = ggml_convert_qtype(w, FP4,
|
||||
device=device,
|
||||
convert_shape_only=self.convert_shape_only)
|
||||
w_fp4_dequant = ggml_convert_fp32(w_quant_fp4, w.shape,
|
||||
reduce(mul, w.shape, 1), FP4)
|
||||
q4_0_mse = mse_loss(w_q4_0_dequant, w)
|
||||
fp4_mse = mse_loss(w_fp4_dequant, w)
|
||||
q4_0_mse, w_quant_q4_0 = self.ggml_mse(w, SYM_INT4, device=device)
|
||||
fp4_mse, w_quant_fp4 = self.ggml_mse(w, FP4, device=device)
|
||||
if q4_0_mse <= fp4_mse:
|
||||
self.qtype = SYM_INT4
|
||||
self.data = w_quant_q4_0
|
||||
else:
|
||||
self.qtype = FP4
|
||||
self.data = w_quant_fp4
|
||||
elif self.qtype == MOFQ8:
|
||||
if device == 'meta':
|
||||
w_quantized = ggml_convert_qtype(w, SYM_INT8,
|
||||
device=device,
|
||||
convert_shape_only=self.convert_shape_only)
|
||||
# TODO: should load from config, the current implementation doesn't support
|
||||
# save/load
|
||||
self.qtype = SYM_INT8
|
||||
else:
|
||||
q8_0_mse, w_quant_q8_0 = self.ggml_mse(w, SYM_INT8, device=device)
|
||||
fp8_mse, w_quant_fp8 = self.ggml_mse(w, FP8, device=device)
|
||||
if q8_0_mse <= fp8_mse:
|
||||
self.qtype = SYM_INT8
|
||||
self.data = w_quant_q8_0
|
||||
else:
|
||||
self.qtype = FP8
|
||||
self.data = w_quant_fp8
|
||||
else:
|
||||
w_quantized = ggml_convert_qtype(w, self.qtype,
|
||||
device=device,
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ class _BaseAutoModelClass:
|
|||
invalidInputError(q_k in ggml_tensor_qtype,
|
||||
f"Unknown load_in_low_bit value: {q_k}, expected:"
|
||||
f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
|
||||
"fp4, fp8, fp16 or mixed_4bit.")
|
||||
"fp4, fp8, fp16, mixed_fp4 or mixed_fp8.")
|
||||
qtype = ggml_tensor_qtype[q_k]
|
||||
|
||||
# In case it needs a second try,
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ NF4 = ggml_tensor_qtype["nf4"]
|
|||
NF3 = ggml_tensor_qtype["nf3"]
|
||||
FP8 = ggml_tensor_qtype["fp8"]
|
||||
FP4 = ggml_tensor_qtype["fp4"]
|
||||
MOFQ4 = ggml_tensor_qtype["mixed_4bit"]
|
||||
MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
|
||||
MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
|
||||
|
||||
|
||||
class XMXChecker:
|
||||
|
|
|
|||
Loading…
Reference in a new issue