diff --git a/python/llm/src/ipex_llm/ggml/quantize.py b/python/llm/src/ipex_llm/ggml/quantize.py
index f86ee122..f7e4eade 100644
--- a/python/llm/src/ipex_llm/ggml/quantize.py
+++ b/python/llm/src/ipex_llm/ggml/quantize.py
@@ -54,6 +54,9 @@ ggml_tensor_qtype = {"sym_int4": 2,   # q4_0 in ggml
                      "sym_int8_rtn": 32,
                      "asym_int4_rtn": 33,
                      "woq_int4": 34,
+                     "torch_fp8_e5m2": 35,
+                     "torch_fp8": 35,
+                     "torch_fp8_e4m3": 36
                      }

 # mixed precison from llama.cpp
diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index 13e27a05..802ef370 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -86,6 +86,8 @@ SYM_INT4_RTN = ggml_tensor_qtype["sym_int4_rtn"]
 SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
 ASYM_INT4_RTN = ggml_tensor_qtype["asym_int4_rtn"]
 WOQ_INT4 = ggml_tensor_qtype["woq_int4"]
+TORCH_FP8E5 = ggml_tensor_qtype["torch_fp8_e5m2"]
+TORCH_FP8E4 = ggml_tensor_qtype["torch_fp8_e4m3"]
 RTN_DTYPE = {
     SYM_INT4_RTN: torch.uint8,
     ASYM_INT4_RTN: torch.uint8,
@@ -106,39 +108,44 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                        imatrix: torch.Tensor=None,
                        in_features: int=None,
                        enable_scale_search: bool=False):
-    QK = ggml.ggml_qk_size(qtype)
-    block_size_in_bytes = ggml.ggml_type_size(qtype)
-
-    invalidInputError(tensor.dtype == torch.float,
-                      "Input tensor must be float32")
-    src = tensor.data.data_ptr()
-    src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
-    n = tensor.numel()  # all elements
-    k = tensor.shape[-1]
-    invalidInputError(k % QK == 0,
-                      f"Last dim of input tensor must be multiple of {QK}")
-
-    dst_size = (n // QK) * block_size_in_bytes
-    if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
-        dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
-                                 device=device)
-        dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
-        if qtype == ASYM_INT4_RTN:
-            scale = torch.empty((n // k) * 2, dtype=torch.float32,
-                                device=device)
-        else:
-            scale = torch.empty(n // k, dtype=torch.float32,
-                                device=device)
-    elif qtype == NF4:
-        # Deepspeed zero3 requires unified dtype,
-        # thus here uses bfloat16 consistent to other layers
-        # dst_size above is computed based on uint8, and for bfloat16,
-        # buffer size should be half
-        dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
-                                 device=device)
+    if qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+        fp8_dtype = torch.float8_e5m2 if qtype == TORCH_FP8E5 else torch.float8_e4m3fn
+        dst_tensor = torch.empty(tensor.shape, device=device, dtype=fp8_dtype)
+        scale = torch.zeros(1, device=device, dtype=torch.float32)
     else:
-        dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
-                                 device=device)
+        QK = ggml.ggml_qk_size(qtype)
+        block_size_in_bytes = ggml.ggml_type_size(qtype)
+
+        invalidInputError(tensor.dtype == torch.float,
+                          "Input tensor must be float32")
+        src = tensor.data.data_ptr()
+        src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
+        n = tensor.numel()  # all elements
+        k = tensor.shape[-1]
+        invalidInputError(k % QK == 0,
+                          f"Last dim of input tensor must be multiple of {QK}")
+
+        dst_size = (n // QK) * block_size_in_bytes
+        if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
+            dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
+                                     device=device)
+            dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
+            if qtype == ASYM_INT4_RTN:
+                scale = torch.empty((n // k) * 2, dtype=torch.float32,
+                                    device=device)
+            else:
+                scale = torch.empty(n // k, dtype=torch.float32,
+                                    device=device)
+        elif qtype == NF4:
+            # Deepspeed zero3 requires unified dtype,
+            # thus here uses bfloat16 consistent to other layers
+            # dst_size above is computed based on uint8, and for bfloat16,
+            # buffer size should be half
+            dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
+                                     device=device)
+        else:
+            dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
+                                     device=device)

     if not convert_shape_only and device != 'meta':
         dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
@@ -158,6 +165,17 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                                              enable_scale_search,
                                              imatrix)
             return dst_tensor, scale.type(torch.float16)
+        elif qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+            import xe_linear
+            tensor_device = tensor.device
+            tensor_xpu = tensor.to("xpu")
+            dst_tensor = dst_tensor.to("xpu")
+            scale = scale.to("xpu")
+
+            xe_linear.dynamic_scaled_fp8_quant(dst_tensor, tensor_xpu, scale)
+
+            # scale = scale.to(tensor_device)
+            dst_tensor = dst_tensor.to(tensor_device)
         else:
             ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
     else:
@@ -171,6 +189,8 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                                                 hist, imatrix)
     if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
         return dst_tensor, scale.type(torch.float16)
+    elif qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+        return dst_tensor, scale
     else:
         return dst_tensor

@@ -179,7 +199,7 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
     if qtype == NF4:
         invalidInputError(tensor.dtype == torch.bfloat16,
                           "NF4 Input tensor must be bfloat16")
-    else:
+    elif qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
         invalidInputError(tensor.dtype == torch.uint8,
                           "Input tensor except NF4 must be uint8")

@@ -208,7 +228,7 @@ def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int
     if qtype == NF4:
         invalidInputError(tensor.dtype == torch.bfloat16,
                           "NF4 Input tensor must be bfloat16")
-    else:
+    elif qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
         invalidInputError(tensor.dtype == torch.uint8,
                           "Input tensor must be uint8")

@@ -319,7 +339,8 @@ class FP4Params(torch.nn.Parameter):
                 qtype=None,
                 imatrix=None,
                 in_features=None,
-                enable_scale_search=False):
+                enable_scale_search=False,
+                torch_fp8_scale=None):
         if data is None:
             data = torch.empty(0)

@@ -332,6 +353,7 @@ class FP4Params(torch.nn.Parameter):
         self.imatrix = imatrix
         self.in_features = in_features
         self.enable_scale_search = enable_scale_search
+        self.torch_fp8_scale = torch_fp8_scale
         return self

     def ggml_mse(self, w, ggml_qtype, device):
@@ -391,7 +413,11 @@ class FP4Params(torch.nn.Parameter):
                                            imatrix=self.imatrix,
                                            in_features=self.in_features,
                                            enable_scale_search=self.enable_scale_search)
-            self.data = w_quantized
+            if self.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+                self.data = w_quantized[0]
+                self.torch_fp8_scale = w_quantized[1]
+            else:
+                self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
         return self
@@ -414,6 +440,8 @@ class FP4Params(torch.nn.Parameter):
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args,
                                                                                 **kwargs)
+        if self.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+            dtype = None
         if (device is not None and device.type == "cpu" and self.data.device.type == "cpu"):
             return self.quantize(device.type)
         elif device is not None and device.type == "meta" and self.data.device.type == "meta":
@@ -424,6 +452,7 @@ class FP4Params(torch.nn.Parameter):
             self.data = ggml_q_format_convet_cpu2xpu(self.data,
                                                      reduce(mul, self._shape, 1),
                                                      self.qtype)
+            fp8_scale = None if self.torch_fp8_scale is None else self.torch_fp8_scale.to(device)
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
                                              non_blocking=non_blocking),
@@ -431,9 +460,11 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_scale_search=self.enable_scale_search)
+                                  enable_scale_search=self.enable_scale_search,
+                                  torch_fp8_scale=fp8_scale)
             return new_param
         elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
+            fp8_scale = None if self.torch_fp8_scale is None else self.torch_fp8_scale.to(device)
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
                                              non_blocking=non_blocking),
@@ -441,7 +472,8 @@ class FP4Params(torch.nn.Parameter):
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_scale_search=self.enable_scale_search)
+                                  enable_scale_search=self.enable_scale_search,
+                                  torch_fp8_scale=fp8_scale)
             ggml_xpu = new_param.data
             new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu,
                                                           reduce(mul, new_param._shape, 1),
@@ -614,6 +646,7 @@ class LowBitLinear(nn.Linear):
         # Due to inconsistent training status in some models like Baichuan-7b-Chat,
         # we should check both self.training and torch.is_inference_mode_enabled().
         is_training = self.training and not torch.is_inference_mode_enabled()
+
         if is_training:
             # below logic is only for training
             autocast_dtype = get_autocast_dtype(x.device.type)
@@ -643,6 +676,8 @@ class LowBitLinear(nn.Linear):

         if self.weight.device.type == "xpu":
             if is_training and x_2d.requires_grad:
+                invalidInputError(self.weight.qtype not in [TORCH_FP8E5, TORCH_FP8E4],
+                                  "TORCH_FP8 training is not supported.")
                 result = MatMulLowBit.apply(x_2d, self.weight, self.out_len)
             else:
                 do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
@@ -654,7 +689,11 @@ class LowBitLinear(nn.Linear):
                 else:
                     w = self.weight.data

-                if use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
+                if self.weight.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+                    import xe_linear
+                    result = xe_linear.run_linear_fp8(x_2d, w, self.bias,
+                                                      self.weight.torch_fp8_scale)
+                elif use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
                         (x_2d.dtype == torch.half or self.conver_to_half):
                     import xe_batch
                     result = xe_batch.batch_forward(x_2d, w, self.qtype)
@@ -682,13 +721,13 @@ class LowBitLinear(nn.Linear):
                 else:
                     invalidInputError(False,
                                       "mp_group is not None, but no supported backend found")
-            if self.bias is not None:
+            if self.bias is not None and self.weight.qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
                 result += self.bias
         else:
             # CPU logic
             # todo may need to set a different number on different platforms
-            invalidInputError(self.qtype != NF3 and self.qtype != NF4 and self.qtype != FP8E4
-                              and self.qtype != FP4 and self.qtype != FP8E5,
+            invalidInputError(self.qtype not in [NF3, NF4, FP8E4, FP4, FP8E5,
+                                                 TORCH_FP8E5, TORCH_FP8E4],
                               "NF3, NF4, FP4 and FP8 quantization are currently not"
                               " supported on CPU")
             if self.training and x.requires_grad:
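Note on `xe_linear.dynamic_scaled_fp8_quant`: the kernel lives in the prebuilt `xe_linear` XPU extension, so its implementation is not part of this patch. Judging from the call shape (preallocated FP8 output, float input, a `torch.zeros(1)` scale written in place) it performs per-tensor dynamic FP8 quantization. Below is a minimal pure-PyTorch sketch of that scheme, assuming the common `scale = amax / fp8_max` convention; the `_ref` function and its exact semantics are hypothetical, not the kernel's actual code.

```python
import torch

def dynamic_scaled_fp8_quant_ref(dst: torch.Tensor, src: torch.Tensor,
                                 scale: torch.Tensor) -> None:
    """Hypothetical per-tensor dynamic FP8 quantization reference.

    dst: preallocated float8_e5m2 / float8_e4m3fn tensor, same shape as src.
    scale: preallocated (1,) float32 tensor, written in place.
    """
    finfo = torch.finfo(dst.dtype)
    amax = src.abs().max().float()
    # Map the largest magnitude onto the edge of the representable FP8
    # range; clamp guards against an all-zero tensor (scale of 0).
    scale.copy_((amax / finfo.max).clamp(min=torch.finfo(torch.float32).tiny))
    # copy_ casts the scaled float32 values down to the FP8 dtype.
    dst.copy_((src.float() / scale).clamp(finfo.min, finfo.max))

# Round trip: dequantization is dst.float() * scale.
w = torch.randn(64, 128)
q = torch.empty_like(w, dtype=torch.float8_e5m2)
s = torch.zeros(1)
dynamic_scaled_fp8_quant_ref(q, w, s)
w_hat = q.float() * s
```

This matches the `ggml_convert_qtype` changes above: for the two TORCH_FP8 qtypes the function returns `(dst_tensor, scale)` with a float32 scale, rather than the float16 scales used by the RTN qtypes.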
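Similarly, `xe_linear.run_linear_fp8(x_2d, w, self.bias, scale)` takes the bias directly, which is why the forward path skips the usual `result += self.bias` when `self.weight.qtype` is `TORCH_FP8E5`/`TORCH_FP8E4`: the bias is applied inside the fused call. A dequantize-then-matmul view of what that call computes (again an assumption; the real kernel may fold the per-tensor scale into the GEMM instead of materializing the dequantized weight):

```python
from typing import Optional

import torch
import torch.nn.functional as F

def run_linear_fp8_ref(x: torch.Tensor, w_fp8: torch.Tensor,
                       bias: Optional[torch.Tensor],
                       scale: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference for the fused FP8 linear: y = x @ (w_fp8 * s)^T + b."""
    w = w_fp8.to(x.dtype) * scale.to(x.dtype)  # undo the per-tensor scale
    return F.linear(x, w, bias)                # bias applied here, not by the caller
```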