parent 3accc31b86
commit e032156518
2 changed files with 84 additions and 42 deletions
@@ -54,6 +54,9 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
                      "sym_int8_rtn": 32,
                      "asym_int4_rtn": 33,
                      "woq_int4": 34,
+                     "torch_fp8_e5m2": 35,
+                     "torch_fp8": 35,
+                     "torch_fp8_e4m3": 36
                      }
 
 # mixed precision from llama.cpp
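Three FP8 qtypes are introduced here: "torch_fp8_e5m2" and "torch_fp8" both map to 35, making "torch_fp8" an alias for the e5m2 format, while "torch_fp8_e4m3" gets 36. For readers unfamiliar with the two torch FP8 dtypes, here is a minimal standalone check (not part of the patch) of their numeric ranges:

    import torch

    # e5m2 trades mantissa bits for exponent range; e4m3 does the opposite.
    e5m2 = torch.finfo(torch.float8_e5m2)
    e4m3 = torch.finfo(torch.float8_e4m3fn)
    print(e5m2.max, e5m2.eps)   # 57344.0 0.25
    print(e4m3.max, e4m3.eps)   # 448.0 0.125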
@@ -86,6 +86,8 @@ SYM_INT4_RTN = ggml_tensor_qtype["sym_int4_rtn"]
 SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
 ASYM_INT4_RTN = ggml_tensor_qtype["asym_int4_rtn"]
 WOQ_INT4 = ggml_tensor_qtype["woq_int4"]
+TORCH_FP8E5 = ggml_tensor_qtype["torch_fp8_e5m2"]
+TORCH_FP8E4 = ggml_tensor_qtype["torch_fp8_e4m3"]
 RTN_DTYPE = {
     SYM_INT4_RTN: torch.uint8,
     ASYM_INT4_RTN: torch.uint8,
@@ -106,39 +108,44 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                        imatrix: torch.Tensor=None,
                        in_features: int=None,
                        enable_scale_search: bool=False):
     QK = ggml.ggml_qk_size(qtype)
     block_size_in_bytes = ggml.ggml_type_size(qtype)
 
     invalidInputError(tensor.dtype == torch.float,
                       "Input tensor must be float32")
     src = tensor.data.data_ptr()
     src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
     n = tensor.numel()  # all elements
     k = tensor.shape[-1]
     invalidInputError(k % QK == 0,
                       f"Last dim of input tensor must be multiple of {QK}")
 
     dst_size = (n // QK) * block_size_in_bytes
     if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
         dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
                                  device=device)
         dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
         if qtype == ASYM_INT4_RTN:
             scale = torch.empty((n // k) * 2, dtype=torch.float32,
                                 device=device)
         else:
             scale = torch.empty(n // k, dtype=torch.float32,
                                 device=device)
     elif qtype == NF4:
         # Deepspeed zero3 requires unified dtype,
         # thus here uses bfloat16 consistent to other layers
         # dst_size above is computed based on uint8, and for bfloat16,
         # buffer size should be half
         dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
                                  device=device)
+    elif qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+        fp8_dtype = torch.float8_e5m2 if qtype == TORCH_FP8E5 else torch.float8_e4m3fn
+        dst_tensor = torch.empty(tensor.shape, device=device, dtype=fp8_dtype)
+        scale = torch.zeros(1, device=device, dtype=torch.float32)
     else:
         dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
                                  device=device)
 
     if not convert_shape_only and device != 'meta':
         dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
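Note the bookkeeping in the new FP8 branch: the destination keeps the tensor's own shape in a native FP8 dtype rather than a packed uint8 block buffer, and the scale is a single float32 element, i.e. per-tensor scaling. Compare the RTN branches, which allocate one scale per output row (two per row for the asymmetric case). A small sketch of the resulting buffer shapes, using made-up layer dimensions:

    import torch

    out_features, in_features = 4096, 11008
    w = torch.randn(out_features, in_features)

    n, k = w.numel(), w.shape[-1]
    rtn_scale = torch.empty(n // k)             # one scale per row -> torch.Size([4096])
    asym_scale = torch.empty((n // k) * 2)      # scale + zero point per row
    fp8_w = torch.empty(w.shape, dtype=torch.float8_e5m2)  # same shape as the weight
    fp8_scale = torch.zeros(1)                  # a single per-tensor scale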
@@ -158,6 +165,17 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                                                enable_scale_search,
                                                imatrix)
             return dst_tensor, scale.type(torch.float16)
+        elif qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+            import xe_linear
+            tensor_device = tensor.device
+            tensor_xpu = tensor.to("xpu")
+            dst_tensor = dst_tensor.to("xpu")
+            scale = scale.to("xpu")
+
+            xe_linear.dynamic_scaled_fp8_quant(dst_tensor, tensor_xpu, scale)
+
+            # scale = scale.to(tensor_device)
+            dst_tensor = dst_tensor.to(tensor_device)
         else:
             ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
     else:
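xe_linear.dynamic_scaled_fp8_quant is an Intel-GPU kernel whose implementation is not part of this diff; the argument order (output, input, scale) and the pre-zeroed scale buffer suggest it computes a per-tensor scale on the fly and writes both outputs in place. A CPU reference sketch of dynamic per-tensor FP8 quantization under that assumption; the real kernel may differ in rounding and saturation details:

    import torch

    def dynamic_scaled_fp8_quant_ref(out: torch.Tensor, x: torch.Tensor,
                                     scale: torch.Tensor) -> None:
        # Assumed semantics: scale = amax(|x|) / fp8_max, out = fp8(x / scale).
        fp8_max = torch.finfo(out.dtype).max
        scale.fill_(x.abs().amax().float() / fp8_max)
        out.copy_((x / scale).clamp(-fp8_max, fp8_max).to(out.dtype))

    x = torch.randn(64, 64)
    out = torch.empty_like(x, dtype=torch.float8_e5m2)
    scale = torch.zeros(1)
    dynamic_scaled_fp8_quant_ref(out, x, scale)
    print((out.float() * scale - x).abs().max())   # round-trip error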
@@ -171,6 +189,8 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                                                hist, imatrix)
     if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
         return dst_tensor, scale.type(torch.float16)
+    elif qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+        return dst_tensor, scale
     else:
         return dst_tensor
@@ -179,7 +199,7 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
     if qtype == NF4:
         invalidInputError(tensor.dtype == torch.bfloat16,
                           "NF4 Input tensor must be bfloat16")
-    else:
+    elif qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
         invalidInputError(tensor.dtype == torch.uint8,
                           "Input tensor except NF4 must be uint8")
@@ -208,7 +228,7 @@ def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int
     if qtype == NF4:
         invalidInputError(tensor.dtype == torch.bfloat16,
                           "NF4 Input tensor must be bfloat16")
-    else:
+    elif qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
         invalidInputError(tensor.dtype == torch.uint8,
                           "Input tensor must be uint8")
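In both repack helpers the strict uint8 check is relaxed for the torch FP8 qtypes: those weights live in native torch FP8 dtypes rather than packed ggml byte buffers, so the old assertion would fail spuriously. A one-line illustration (standalone, not from the patch):

    import torch

    t = torch.empty(8, 8, dtype=torch.float8_e5m2)
    print(t.dtype == torch.uint8)   # False -- hence the elif guard above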
@@ -319,7 +339,8 @@ class FP4Params(torch.nn.Parameter):
                 qtype=None,
                 imatrix=None,
                 in_features=None,
-                enable_scale_search=False):
+                enable_scale_search=False,
+                torch_fp8_scale=None):
         if data is None:
             data = torch.empty(0)
@@ -332,6 +353,7 @@ class FP4Params(torch.nn.Parameter):
         self.imatrix = imatrix
         self.in_features = in_features
         self.enable_scale_search = enable_scale_search
+        self.torch_fp8_scale = torch_fp8_scale
         return self
 
     def ggml_mse(self, w, ggml_qtype, device):
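torch_fp8_scale rides along as a plain attribute on the parameter, set in __new__ like imatrix and the other metadata. A minimal sketch of this nn.Parameter subclass pattern (hypothetical class, not from the patch):

    import torch

    class ScaledParam(torch.nn.Parameter):
        # nn.Parameter has no __init__ of its own, so extra metadata is
        # attached in __new__, mirroring what FP4Params does above.
        def __new__(cls, data, scale=None, requires_grad=False):
            self = torch.Tensor._make_subclass(cls, data, requires_grad)
            self.scale = scale
            return self

    p = ScaledParam(torch.zeros(2, 2), scale=torch.ones(1))
    print(p.scale)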
@@ -391,7 +413,11 @@ class FP4Params(torch.nn.Parameter):
                                              imatrix=self.imatrix,
                                              in_features=self.in_features,
                                              enable_scale_search=self.enable_scale_search)
-            self.data = w_quantized
+            if self.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+                self.data = w_quantized[0]
+                self.torch_fp8_scale = w_quantized[1]
+            else:
+                self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
             return self
@@ -414,6 +440,8 @@ class FP4Params(torch.nn.Parameter):
 
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+        if self.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+            dtype = None
         if (device is not None and device.type == "cpu" and self.data.device.type == "cpu"):
             return self.quantize(device.type)
         elif device is not None and device.type == "meta" and self.data.device.type == "meta":
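Forcing dtype to None for FP8 parameters keeps the later super().to(device=..., dtype=...) calls from silently casting the quantized storage to half or float, which would throw away the FP8 encoding. The torch behavior being relied on:

    import torch

    w = torch.randn(4, 4).to(torch.float8_e5m2)
    moved = w.to(device="cpu", dtype=None)   # dtype=None keeps the storage dtype
    print(moved.dtype)                       # torch.float8_e5m2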
@@ -424,6 +452,7 @@ class FP4Params(torch.nn.Parameter):
             self.data = ggml_q_format_convet_cpu2xpu(self.data,
                                                      reduce(mul, self._shape, 1),
                                                      self.qtype)
+            fp8_scale = None if self.torch_fp8_scale is None else self.torch_fp8_scale.to(device)
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
                                              non_blocking=non_blocking),
@@ -431,9 +460,11 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_scale_search=self.enable_scale_search)
+                                  enable_scale_search=self.enable_scale_search,
+                                  torch_fp8_scale=fp8_scale)
             return new_param
         elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
+            fp8_scale = None if self.torch_fp8_scale is None else self.torch_fp8_scale.to(device)
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
                                              non_blocking=non_blocking),
@@ -441,7 +472,8 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_scale_search=self.enable_scale_search)
+                                  enable_scale_search=self.enable_scale_search,
+                                  torch_fp8_scale=fp8_scale)
             ggml_xpu = new_param.data
             new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu,
                                                           reduce(mul, new_param._shape, 1),
@@ -614,6 +646,7 @@ class LowBitLinear(nn.Linear):
         # Due to inconsistent training status in some models like Baichuan-7b-Chat,
         # we should check both self.training and torch.is_inference_mode_enabled().
         is_training = self.training and not torch.is_inference_mode_enabled()
+
         if is_training:
             # below logic is only for training
             autocast_dtype = get_autocast_dtype(x.device.type)
@@ -643,6 +676,8 @@ class LowBitLinear(nn.Linear):
 
         if self.weight.device.type == "xpu":
             if is_training and x_2d.requires_grad:
+                invalidInputError(self.weight.qtype not in [TORCH_FP8E5, TORCH_FP8E4],
+                                  "TORCH_FP8 training is not supported.")
                 result = MatMulLowBit.apply(x_2d, self.weight, self.out_len)
             else:
                 do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
@@ -654,7 +689,11 @@ class LowBitLinear(nn.Linear):
                 else:
                     w = self.weight.data
 
-                if use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
+                if self.weight.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+                    import xe_linear
+                    result = xe_linear.run_linear_fp8(x_2d, w, self.bias,
+                                                      self.weight.torch_fp8_scale)
+                elif use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
                         (x_2d.dtype == torch.half or self.conver_to_half):
                     import xe_batch
                     result = xe_batch.batch_forward(x_2d, w, self.qtype)
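xe_linear.run_linear_fp8 is likewise an XPU kernel not shown in this diff. Since it takes the bias and the per-tensor scale as arguments, it presumably fuses dequantization, matmul, and bias add; a dequantize-then-F.linear CPU sketch of the assumed math, which also explains why the generic bias add is skipped for FP8 in the next hunk:

    import torch
    import torch.nn.functional as F

    def run_linear_fp8_ref(x, w_fp8, bias, scale):
        # Assumed semantics: y = x @ (fp8_weight * scale)^T + bias,
        # with the bias fused into the kernel on the real XPU path.
        w = w_fp8.to(x.dtype) * scale.to(x.dtype)
        return F.linear(x, w, bias)

    x = torch.randn(2, 32, dtype=torch.float16)
    w = (torch.randn(16, 32) / 8).to(torch.float8_e5m2)
    y = run_linear_fp8_ref(x, w, torch.zeros(16, dtype=torch.float16),
                           torch.ones(1))
    print(y.shape)   # torch.Size([2, 16])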
@@ -682,13 +721,13 @@ class LowBitLinear(nn.Linear):
             else:
                 invalidInputError(False, "mp_group is not None, but no supported backend found")
 
-            if self.bias is not None:
+            if self.bias is not None and self.weight.qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
                 result += self.bias
         else:
             # CPU logic
             # todo may need to set a different number on different platforms
-            invalidInputError(self.qtype != NF3 and self.qtype != NF4 and self.qtype != FP8E4
-                              and self.qtype != FP4 and self.qtype != FP8E5,
+            invalidInputError(self.qtype not in [NF3, NF4, FP8E4, FP4, FP8E5,
+                                                 TORCH_FP8E5, TORCH_FP8E4],
                               "NF3, NF4, FP4 and FP8 quantization are currently not"
                               " supported on CPU")
             if self.training and x.requires_grad:
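End to end, the new qtype plugs into the existing low-bit loading path. Assuming this module is ipex-llm's low_bit_linear (the identifiers match) and that the new string keys are accepted by load_in_low_bit like the other entries in ggml_tensor_qtype, a hypothetical usage sketch:

    from ipex_llm.transformers import AutoModelForCausalLM

    # Hypothetical usage; assumes load_in_low_bit accepts the new key.
    # "torch_fp8" aliases "torch_fp8_e5m2" (both map to 35 in the table above).
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        load_in_low_bit="torch_fp8",
        trust_remote_code=True,
    )
    model = model.to("xpu")   # weights become FP8 tensors with per-tensor scales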