Support torch_fp8 (#13196)

* support torch_fp8
Yina Chen 2025-06-04 20:08:01 +08:00 committed by GitHub
parent 3accc31b86
commit e032156518
2 changed files with 84 additions and 42 deletions


@@ -54,6 +54,9 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
"sym_int8_rtn": 32,
"asym_int4_rtn": 33,
"woq_int4": 34,
"torch_fp8_e5m2": 35,
"torch_fp8": 35,
"torch_fp8_e4m3": 36
}
# mixed precision from llama.cpp
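
For reference, the two new ids map onto PyTorch's native float8 dtypes, and "torch_fp8" is simply an alias for "torch_fp8_e5m2" (both resolve to 35). A minimal sketch of that mapping; the helper below is illustrative, not part of the patch:

    # Illustrative only: how the new qtype strings correspond to torch's float8 dtypes
    # (requires a PyTorch build with float8 support, i.e. >= 2.1).
    import torch

    TORCH_FP8_DTYPES = {
        "torch_fp8_e5m2": torch.float8_e5m2,    # id 35
        "torch_fp8":      torch.float8_e5m2,    # alias of torch_fp8_e5m2, same id 35
        "torch_fp8_e4m3": torch.float8_e4m3fn,  # id 36
    }

    def torch_fp8_dtype(qtype_name: str) -> torch.dtype:
        # hypothetical helper: resolve a qtype string to its float8 storage dtype
        return TORCH_FP8_DTYPES[qtype_name]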


@@ -86,6 +86,8 @@ SYM_INT4_RTN = ggml_tensor_qtype["sym_int4_rtn"]
SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
ASYM_INT4_RTN = ggml_tensor_qtype["asym_int4_rtn"]
WOQ_INT4 = ggml_tensor_qtype["woq_int4"]
TORCH_FP8E5 = ggml_tensor_qtype["torch_fp8_e5m2"]
TORCH_FP8E4 = ggml_tensor_qtype["torch_fp8_e4m3"]
RTN_DTYPE = {
SYM_INT4_RTN: torch.uint8,
ASYM_INT4_RTN: torch.uint8,
@@ -106,6 +108,11 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
imatrix: torch.Tensor=None,
in_features: int=None,
enable_scale_search: bool=False):
if qtype in [TORCH_FP8E5, TORCH_FP8E4]:
fp8_dtype = torch.float8_e5m2 if qtype == TORCH_FP8E5 else torch.float8_e4m3fn
dst_tensor = torch.empty(tensor.shape, device=device, dtype=fp8_dtype)
scale = torch.zeros(1, device=device, dtype=torch.float32)
else:
QK = ggml.ggml_qk_size(qtype)
block_size_in_bytes = ggml.ggml_type_size(qtype)
@@ -158,6 +165,17 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
enable_scale_search,
imatrix)
return dst_tensor, scale.type(torch.float16)
elif qtype in [TORCH_FP8E5, TORCH_FP8E4]:
import xe_linear
tensor_device = tensor.device
tensor_xpu = tensor.to("xpu")
dst_tensor = dst_tensor.to("xpu")
scale = scale.to("xpu")
xe_linear.dynamic_scaled_fp8_quant(dst_tensor, tensor_xpu, scale)
# scale = scale.to(tensor_device)
dst_tensor = dst_tensor.to(tensor_device)
else:
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
else:
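
The quantization itself is delegated to the xe_linear XPU extension, whose internals are not part of this diff. As a reference-only sketch of what dynamic per-tensor FP8 quantization conventionally computes (an assumption about the kernel's semantics, not its implementation):

    # Reference sketch: scale = amax / fp8_max, store w / scale in float8.
    import torch

    def dynamic_scaled_fp8_quant_ref(w: torch.Tensor, fp8_dtype=torch.float8_e5m2):
        fp8_max = torch.finfo(fp8_dtype).max
        scale = w.abs().amax().float() / fp8_max                    # per-tensor scale
        scale = torch.clamp(scale, min=torch.finfo(torch.float32).tiny)
        w_fp8 = (w.float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
        return w_fp8, scale.reshape(1)                              # mirrors (dst_tensor, scale)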
@@ -171,6 +189,8 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
hist, imatrix)
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
return dst_tensor, scale.type(torch.float16)
elif qtype in [TORCH_FP8E5, TORCH_FP8E4]:
return dst_tensor, scale
else:
return dst_tensor
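
So for the torch_fp8 qtypes, ggml_convert_qtype now returns a (tensor, scale) pair instead of a single packed tensor. A hedged usage sketch, assuming the module's own imports and an Intel XPU build with the xe_linear extension (the keyword name device is taken from the fragment above):

    # hypothetical call site
    w = torch.randn(4096, 4096, dtype=torch.float16)
    w_fp8, w_scale = ggml_convert_qtype(w, TORCH_FP8E5, device="xpu")
    assert w_fp8.dtype == torch.float8_e5m2 and w_scale.numel() == 1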
@@ -179,7 +199,7 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
if qtype == NF4:
invalidInputError(tensor.dtype == torch.bfloat16,
"NF4 Input tensor must be bfloat16")
-else:
+elif qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
invalidInputError(tensor.dtype == torch.uint8,
"Input tensor except NF4 must be uint8")
@@ -208,7 +228,7 @@ def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int
if qtype == NF4:
invalidInputError(tensor.dtype == torch.bfloat16,
"NF4 Input tensor must be bfloat16")
-else:
+elif qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
invalidInputError(tensor.dtype == torch.uint8,
"Input tensor must be uint8")
@@ -319,7 +339,8 @@ class FP4Params(torch.nn.Parameter):
qtype=None,
imatrix=None,
in_features=None,
-enable_scale_search=False):
+enable_scale_search=False,
+torch_fp8_scale=None):
if data is None:
data = torch.empty(0)
@@ -332,6 +353,7 @@ class FP4Params(torch.nn.Parameter):
self.imatrix = imatrix
self.in_features = in_features
self.enable_scale_search = enable_scale_search
self.torch_fp8_scale = torch_fp8_scale
return self
def ggml_mse(self, w, ggml_qtype, device):
@@ -391,6 +413,10 @@ class FP4Params(torch.nn.Parameter):
imatrix=self.imatrix,
in_features=self.in_features,
enable_scale_search=self.enable_scale_search)
if self.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
self.data = w_quantized[0]
self.torch_fp8_scale = w_quantized[1]
else:
self.data = w_quantized
self.quantized = True
self._shape = w.shape
@@ -414,6 +440,8 @@ class FP4Params(torch.nn.Parameter):
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if self.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
dtype = None
if (device is not None and device.type == "cpu" and self.data.device.type == "cpu"):
return self.quantize(device.type)
elif device is not None and device.type == "meta" and self.data.device.type == "meta":
@@ -424,6 +452,7 @@ class FP4Params(torch.nn.Parameter):
self.data = ggml_q_format_convet_cpu2xpu(self.data,
reduce(mul, self._shape, 1),
self.qtype)
fp8_scale = None if self.torch_fp8_scale is None else self.torch_fp8_scale.to(device)
new_param = FP4Params(super().to(device=device,
dtype=dtype,
non_blocking=non_blocking),
@@ -431,9 +460,11 @@ class FP4Params(torch.nn.Parameter):
quantized=self.quantized,
_shape=self._shape,
qtype=self.qtype,
-enable_scale_search=self.enable_scale_search)
+enable_scale_search=self.enable_scale_search,
+torch_fp8_scale=fp8_scale)
return new_param
elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
fp8_scale = None if self.torch_fp8_scale is None else self.torch_fp8_scale.to(device)
new_param = FP4Params(super().to(device=device,
dtype=dtype,
non_blocking=non_blocking),
@@ -441,7 +472,8 @@ class FP4Params(torch.nn.Parameter):
quantized=self.quantized,
_shape=self._shape,
qtype=self.qtype,
-enable_scale_search=self.enable_scale_search)
+enable_scale_search=self.enable_scale_search,
+torch_fp8_scale=fp8_scale)
ggml_xpu = new_param.data
new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu,
reduce(mul, new_param._shape, 1),
@@ -614,6 +646,7 @@ class LowBitLinear(nn.Linear):
# Due to inconsistent training status in some models like Baichuan-7b-Chat,
# we should check both self.training and torch.is_inference_mode_enabled().
is_training = self.training and not torch.is_inference_mode_enabled()
if is_training:
# below logic is only for training
autocast_dtype = get_autocast_dtype(x.device.type)
@@ -643,6 +676,8 @@ class LowBitLinear(nn.Linear):
if self.weight.device.type == "xpu":
if is_training and x_2d.requires_grad:
invalidInputError(self.weight.qtype not in [TORCH_FP8E5, TORCH_FP8E4],
"TORCH_FP8 training is not supported.")
result = MatMulLowBit.apply(x_2d, self.weight, self.out_len)
else:
do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
@@ -654,7 +689,11 @@ class LowBitLinear(nn.Linear):
else:
w = self.weight.data
-if use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
+if self.weight.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+    import xe_linear
+    result = xe_linear.run_linear_fp8(x_2d, w, self.bias,
+                                      self.weight.torch_fp8_scale)
+elif use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
(x_2d.dtype == torch.half or self.conver_to_half):
import xe_batch
result = xe_batch.batch_forward(x_2d, w, self.qtype)
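
xe_linear.run_linear_fp8 takes the activation, the fp8 weight, the bias and the per-tensor scale in one call; the diff does not show its kernel, but a plausible reference for the math it stands in for (an assumption, noted here only to make the bias handling below easier to follow) is:

    # Reference-only: dequantize with the per-tensor scale, then a standard linear.
    import torch
    import torch.nn.functional as F

    def run_linear_fp8_ref(x, w_fp8, bias, scale):
        w = w_fp8.to(x.dtype) * scale.to(x.dtype)  # w ~= w_fp8 * scale
        return F.linear(x, w, bias)                # bias applied here, which is why
                                                   # forward() skips the usual bias add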
@@ -682,13 +721,13 @@ class LowBitLinear(nn.Linear):
else:
invalidInputError(False, "mp_group is not None, but no supported backend found")
-if self.bias is not None:
+if self.bias is not None and self.weight.qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
result += self.bias
else:
# CPU logic
# todo may need to set a different number on different platforms
-invalidInputError(self.qtype != NF3 and self.qtype != NF4 and self.qtype != FP8E4
-                  and self.qtype != FP4 and self.qtype != FP8E5,
+invalidInputError(self.qtype not in [NF3, NF4, FP8E4, FP4, FP8E5,
+                                     TORCH_FP8E5, TORCH_FP8E4],
"NF3, NF4, FP4 and FP8 quantization are currently not"
" supported on CPU")
if self.training and x.requires_grad:
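
Since the new keys live in ggml_tensor_qtype, the mode is presumably selected through the usual low-bit loading path; a hedged end-to-end sketch (the load_in_low_bit value and the model id are assumptions, not confirmed by this diff):

    # hypothetical usage on an Intel XPU
    from ipex_llm.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        load_in_low_bit="torch_fp8",       # assumed to be accepted like the other qtype keys
        trust_remote_code=True)
    model = model.half().to("xpu")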