Support torch_fp8 (#13196)

* support torch_fp8
Yina Chen 2025-06-04 20:08:01 +08:00 committed by GitHub
parent 3accc31b86
commit e032156518
2 changed files with 84 additions and 42 deletions


@@ -54,6 +54,9 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
                      "sym_int8_rtn": 32,
                      "asym_int4_rtn": 33,
                      "woq_int4": 34,
+                     "torch_fp8_e5m2": 35,
+                     "torch_fp8": 35,
+                     "torch_fp8_e4m3": 36
                      }
 # mixed precison from llama.cpp
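The three new keys expose PyTorch-native fp8 storage: "torch_fp8_e5m2" and its alias "torch_fp8" share id 35, while "torch_fp8_e4m3" gets id 36. A usage sketch, assuming the new strings are accepted wherever the other ggml_tensor_qtype keys are (for example ipex-llm's load_in_low_bit argument; the model id below is only a placeholder):

from ipex_llm.transformers import AutoModelForCausalLM

# Assumption: the new qtype strings plug into the existing low-bit loading path.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",      # placeholder model id
    load_in_low_bit="torch_fp8_e5m2",     # or the alias "torch_fp8", or "torch_fp8_e4m3"
    trust_remote_code=True,
)
model = model.to("xpu")                   # fp8 payload and per-tensor scale move together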


@@ -86,6 +86,8 @@ SYM_INT4_RTN = ggml_tensor_qtype["sym_int4_rtn"]
 SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
 ASYM_INT4_RTN = ggml_tensor_qtype["asym_int4_rtn"]
 WOQ_INT4 = ggml_tensor_qtype["woq_int4"]
+TORCH_FP8E5 = ggml_tensor_qtype["torch_fp8_e5m2"]
+TORCH_FP8E4 = ggml_tensor_qtype["torch_fp8_e4m3"]
 RTN_DTYPE = {
     SYM_INT4_RTN: torch.uint8,
     ASYM_INT4_RTN: torch.uint8,
@@ -106,6 +108,11 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                        imatrix: torch.Tensor=None,
                        in_features: int=None,
                        enable_scale_search: bool=False):
-    QK = ggml.ggml_qk_size(qtype)
-    block_size_in_bytes = ggml.ggml_type_size(qtype)
+    if qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+        fp8_dtype = torch.float8_e5m2 if qtype == TORCH_FP8E5 else torch.float8_e4m3fn
+        dst_tensor = torch.empty(tensor.shape, device=device, dtype=fp8_dtype)
+        scale = torch.zeros(1, device=device, dtype=torch.float32)
+    else:
+        QK = ggml.ggml_qk_size(qtype)
+        block_size_in_bytes = ggml.ggml_type_size(qtype)
@@ -158,6 +165,17 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                                               enable_scale_search,
                                               imatrix)
                 return dst_tensor, scale.type(torch.float16)
+            elif qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+                import xe_linear
+                tensor_device = tensor.device
+                tensor_xpu = tensor.to("xpu")
+                dst_tensor = dst_tensor.to("xpu")
+                scale = scale.to("xpu")
+                xe_linear.dynamic_scaled_fp8_quant(dst_tensor, tensor_xpu, scale)
+                # scale = scale.to(tensor_device)
+                dst_tensor = dst_tensor.to(tensor_device)
             else:
                 ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
         else:
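For reference, a plain-PyTorch sketch of the per-tensor dynamic scaling that xe_linear.dynamic_scaled_fp8_quant presumably performs on XPU (the function name below and the exact formula are assumptions; the real kernel writes into dst_tensor and scale in place):

import torch

def dynamic_scaled_fp8_quant_ref(src: torch.Tensor, fp8_dtype=torch.float8_e5m2):
    # Assumed semantics: pick one scale per tensor so the largest |src| maps to
    # the largest finite fp8 value, then cast the rescaled tensor to fp8.
    fp8_max = torch.finfo(fp8_dtype).max
    scale = src.abs().max().float().clamp(min=1e-12) / fp8_max
    dst = (src.float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return dst, scale.reshape(1)  # mirrors the dst_tensor / 1-element scale set up earlier

Note that with the scale.to(tensor_device) line left commented out, the scale produced by this path stays on the XPU while dst_tensor is copied back to the original device.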
@@ -171,6 +189,8 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                                           hist, imatrix)
     if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
         return dst_tensor, scale.type(torch.float16)
+    elif qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+        return dst_tensor, scale
     else:
         return dst_tensor
@@ -179,7 +199,7 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
     if qtype == NF4:
         invalidInputError(tensor.dtype == torch.bfloat16,
                           "NF4 Input tensor must be bfloat16")
-    else:
+    elif qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
         invalidInputError(tensor.dtype == torch.uint8,
                           "Input tensor except NF4 must be uint8")
@@ -208,7 +228,7 @@ def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int
     if qtype == NF4:
         invalidInputError(tensor.dtype == torch.bfloat16,
                           "NF4 Input tensor must be bfloat16")
-    else:
+    elif qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
         invalidInputError(tensor.dtype == torch.uint8,
                           "Input tensor must be uint8")
@@ -319,7 +339,8 @@ class FP4Params(torch.nn.Parameter):
                 qtype=None,
                 imatrix=None,
                 in_features=None,
-                enable_scale_search=False):
+                enable_scale_search=False,
+                torch_fp8_scale=None):
         if data is None:
             data = torch.empty(0)
@@ -332,6 +353,7 @@ class FP4Params(torch.nn.Parameter):
         self.imatrix = imatrix
         self.in_features = in_features
         self.enable_scale_search = enable_scale_search
+        self.torch_fp8_scale = torch_fp8_scale
         return self

     def ggml_mse(self, w, ggml_qtype, device):
@@ -391,6 +413,10 @@ class FP4Params(torch.nn.Parameter):
                                                imatrix=self.imatrix,
                                                in_features=self.in_features,
                                                enable_scale_search=self.enable_scale_search)
-            self.data = w_quantized
+            if self.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+                self.data = w_quantized[0]
+                self.torch_fp8_scale = w_quantized[1]
+            else:
+                self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
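For torch_fp8 qtypes, quantize() now keeps two pieces of state: the fp8 payload in self.data and the per-tensor scale in self.torch_fp8_scale. A self-contained round-trip sketch of what that pair represents, in plain PyTorch on CPU (the real path quantizes via xe_linear on XPU):

import torch

w = torch.randn(256, 256)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale = w.abs().max().float() / fp8_max                 # per-tensor dynamic scale
payload = (w / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
# FP4Params stores these as self.data (payload) and self.torch_fp8_scale (scale);
# recovering a usable weight is a single multiply:
w_back = payload.to(torch.float32) * scale
print((w - w_back).abs().max() / w.abs().max())         # small relative error; e4m3 keeps 3 mantissa bits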
@@ -414,6 +440,8 @@ class FP4Params(torch.nn.Parameter):
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+        if self.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+            dtype = None
         if (device is not None and device.type == "cpu" and self.data.device.type == "cpu"):
             return self.quantize(device.type)
         elif device is not None and device.type == "meta" and self.data.device.type == "meta":
@@ -424,6 +452,7 @@ class FP4Params(torch.nn.Parameter):
             self.data = ggml_q_format_convet_cpu2xpu(self.data,
                                                      reduce(mul, self._shape, 1),
                                                      self.qtype)
+            fp8_scale = None if self.torch_fp8_scale is None else self.torch_fp8_scale.to(device)
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
                                              non_blocking=non_blocking),
@@ -431,9 +460,11 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_scale_search=self.enable_scale_search)
+                                  enable_scale_search=self.enable_scale_search,
+                                  torch_fp8_scale=fp8_scale)
             return new_param
         elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
+            fp8_scale = None if self.torch_fp8_scale is None else self.torch_fp8_scale.to(device)
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
                                              non_blocking=non_blocking),
@@ -441,7 +472,8 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_scale_search=self.enable_scale_search)
+                                  enable_scale_search=self.enable_scale_search,
+                                  torch_fp8_scale=fp8_scale)
             ggml_xpu = new_param.data
             new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu,
                                                           reduce(mul, new_param._shape, 1),
@@ -614,6 +646,7 @@ class LowBitLinear(nn.Linear):
         # Due to inconsistent training status in some models like Baichuan-7b-Chat,
         # we should check both self.training and torch.is_inference_mode_enabled().
         is_training = self.training and not torch.is_inference_mode_enabled()
+
         if is_training:
             # below logic is only for training
             autocast_dtype = get_autocast_dtype(x.device.type)
@@ -643,6 +676,8 @@ class LowBitLinear(nn.Linear):
         if self.weight.device.type == "xpu":
             if is_training and x_2d.requires_grad:
+                invalidInputError(self.weight.qtype not in [TORCH_FP8E5, TORCH_FP8E4],
+                                  "TORCH_FP8 training is not supported.")
                 result = MatMulLowBit.apply(x_2d, self.weight, self.out_len)
             else:
                 do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
@@ -654,7 +689,11 @@ class LowBitLinear(nn.Linear):
                 else:
                     w = self.weight.data
-                if use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
+                if self.weight.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+                    import xe_linear
+                    result = xe_linear.run_linear_fp8(x_2d, w, self.bias,
+                                                      self.weight.torch_fp8_scale)
+                elif use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
                         (x_2d.dtype == torch.half or self.conver_to_half):
                     import xe_batch
                     result = xe_batch.batch_forward(x_2d, w, self.qtype)
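xe_linear.run_linear_fp8 is only invoked on XPU; a plain-PyTorch sketch of the math it presumably implements (an assumption: dequantize with the per-tensor scale, then an ordinary linear with the bias folded in, which is why the later bias add is skipped for torch_fp8 qtypes):

import torch
import torch.nn.functional as F
from typing import Optional

def run_linear_fp8_ref(x_2d: torch.Tensor, w_fp8: torch.Tensor,
                       bias: Optional[torch.Tensor], scale: torch.Tensor) -> torch.Tensor:
    # Assumed semantics of xe_linear.run_linear_fp8: w_fp8 is the fp8 weight payload,
    # scale is the per-tensor torch_fp8_scale, and the bias is applied inside the
    # kernel, so the caller must not add self.bias again afterwards.
    w = w_fp8.to(x_2d.dtype) * scale.to(x_2d.dtype)
    return F.linear(x_2d, w, bias)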
@@ -682,13 +721,13 @@ class LowBitLinear(nn.Linear):
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")

-            if self.bias is not None:
+            if self.bias is not None and self.weight.qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
                 result += self.bias
         else:
             # CPU logic
             # todo may need to set a different number on different platforms
-            invalidInputError(self.qtype != NF3 and self.qtype != NF4 and self.qtype != FP8E4
-                              and self.qtype != FP4 and self.qtype != FP8E5,
+            invalidInputError(self.qtype not in [NF3, NF4, FP8E4, FP4, FP8E5,
+                                                 TORCH_FP8E5, TORCH_FP8E4],
                               "NF3, NF4, FP4 and FP8 quantization are currently not"
                               " supported on CPU")
             if self.training and x.requires_grad: