Support MoFQ4 on arc (#9301)
* init * update * fix style * fix style * fix style * meet comments
This commit is contained in:
		
							parent
							
								
									8ef8e25178
								
							
						
					
					
						commit
						2262ae4d13
					
				
					 4 changed files with 68 additions and 6 deletions
				
			
		| 
						 | 
				
			
			@ -999,6 +999,24 @@ _lib.ggml_dequantize_q4_0.argtypes = [
 | 
			
		|||
_lib.ggml_quantize_q4_0.restype = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ggml_dequantize(
 | 
			
		||||
    src: ctypes.c_void_p,
 | 
			
		||||
    dst: ctypes.c_void_p,
 | 
			
		||||
    k: ctypes.c_size_t,
 | 
			
		||||
    qtype: ctypes.c_int
 | 
			
		||||
):
 | 
			
		||||
    _lib.ggml_dequantize(src, dst, k, qtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_lib.ggml_dequantize.argtypes = [
 | 
			
		||||
    ctypes.c_void_p,
 | 
			
		||||
    ctypes.c_void_p,
 | 
			
		||||
    ctypes.c_size_t,
 | 
			
		||||
    ctypes.c_int
 | 
			
		||||
]
 | 
			
		||||
_lib.ggml_dequantize.restype = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ggml_q_format_convet_cpu2xpu(
 | 
			
		||||
    src: ctypes.c_void_p,
 | 
			
		||||
    dst: ctypes.c_void_p,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -34,7 +34,8 @@ ggml_tensor_qtype = {"sym_int4": 2,   # q4_0 in ggml
 | 
			
		|||
                     "nf3": 11,
 | 
			
		||||
                     "fp16": 12,
 | 
			
		||||
                     "fp8": 15,
 | 
			
		||||
                     "fp4": 16}
 | 
			
		||||
                     "fp4": 16,
 | 
			
		||||
                     "mixed_4bit": 17}     # Mixture of Formats Quantization 4 bits
 | 
			
		||||
 | 
			
		||||
_llama_quantize_type = {"q4_0": 2,
 | 
			
		||||
                        "q4_1": 3,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -66,6 +66,7 @@ NF4 = ggml_tensor_qtype["nf4"]
 | 
			
		|||
NF3 = ggml_tensor_qtype["nf3"]
 | 
			
		||||
FP8 = ggml_tensor_qtype["fp8"]
 | 
			
		||||
FP4 = ggml_tensor_qtype["fp4"]
 | 
			
		||||
MOFQ4 = ggml_tensor_qtype["mixed_4bit"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
 | 
			
		||||
| 
						 | 
				
			
			@ -158,6 +159,19 @@ def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
 | 
			
		|||
    return dst_tensor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ggml_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int, qtype: int):
 | 
			
		||||
    invalidInputError(tensor.dtype == torch.uint8,
 | 
			
		||||
                      "Input tensor must be uint8")
 | 
			
		||||
    src_ptr = ctypes.c_void_p(tensor.data.data_ptr())
 | 
			
		||||
 | 
			
		||||
    dst_size = k
 | 
			
		||||
    dst_tensor = torch.empty(weight_shape, dtype=torch.float)
 | 
			
		||||
    dst_ptr = ctypes.c_void_p(dst_tensor.data.data_ptr())
 | 
			
		||||
 | 
			
		||||
    ggml.ggml_dequantize(src_ptr, dst_ptr, k, qtype)
 | 
			
		||||
    return dst_tensor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Rename to FP4Params to trigger initializing
 | 
			
		||||
# the params layer with all parameters on the CPU
 | 
			
		||||
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
 | 
			
		||||
| 
						 | 
				
			
			@ -183,10 +197,39 @@ class FP4Params(torch.nn.Parameter):
 | 
			
		|||
    def quantize(self, device=None):
 | 
			
		||||
        if not self.quantized:
 | 
			
		||||
            w = self.data.contiguous().float()
 | 
			
		||||
            w_quantized = ggml_convert_qtype(w, self.qtype,
 | 
			
		||||
                                             device=device,
 | 
			
		||||
                                             convert_shape_only=self.convert_shape_only)
 | 
			
		||||
            self.data = w_quantized
 | 
			
		||||
            if self.qtype == MOFQ4:
 | 
			
		||||
                if device == 'meta':
 | 
			
		||||
                    w_quantized = ggml_convert_qtype(w, SYM_INT4,
 | 
			
		||||
                                                     device=device,
 | 
			
		||||
                                                     convert_shape_only=self.convert_shape_only)
 | 
			
		||||
                    # TODO: should load from config, the current implementation doesn't support
 | 
			
		||||
                    # save/load
 | 
			
		||||
                    self.qtype = SYM_INT4
 | 
			
		||||
                else:
 | 
			
		||||
                    from torch.nn.functional import mse_loss
 | 
			
		||||
                    w_quant_q4_0 = ggml_convert_qtype(w, SYM_INT4,
 | 
			
		||||
                                                      device=device,
 | 
			
		||||
                                                      convert_shape_only=self.convert_shape_only)
 | 
			
		||||
                    w_q4_0_dequant = ggml_convert_fp32(w_quant_q4_0, w.shape,
 | 
			
		||||
                                                       reduce(mul, w.shape, 1), SYM_INT4)
 | 
			
		||||
                    w_quant_fp4 = ggml_convert_qtype(w, FP4,
 | 
			
		||||
                                                     device=device,
 | 
			
		||||
                                                     convert_shape_only=self.convert_shape_only)
 | 
			
		||||
                    w_fp4_dequant = ggml_convert_fp32(w_quant_fp4, w.shape,
 | 
			
		||||
                                                      reduce(mul, w.shape, 1), FP4)
 | 
			
		||||
                    q4_0_mse = mse_loss(w_q4_0_dequant, w)
 | 
			
		||||
                    fp4_mse = mse_loss(w_fp4_dequant, w)
 | 
			
		||||
                    if q4_0_mse <= fp4_mse:
 | 
			
		||||
                        self.qtype = SYM_INT4
 | 
			
		||||
                        self.data = w_quant_q4_0
 | 
			
		||||
                    else:
 | 
			
		||||
                        self.qtype = FP4
 | 
			
		||||
                        self.data = w_quant_fp4
 | 
			
		||||
            else:
 | 
			
		||||
                w_quantized = ggml_convert_qtype(w, self.qtype,
 | 
			
		||||
                                                 device=device,
 | 
			
		||||
                                                 convert_shape_only=self.convert_shape_only)
 | 
			
		||||
                self.data = w_quantized
 | 
			
		||||
            self.quantized = True
 | 
			
		||||
            self._shape = w.shape
 | 
			
		||||
        return self
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -111,7 +111,7 @@ class _BaseAutoModelClass:
 | 
			
		|||
        invalidInputError(q_k in ggml_tensor_qtype,
 | 
			
		||||
                          f"Unknown load_in_low_bit value: {q_k}, expected:"
 | 
			
		||||
                          f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
 | 
			
		||||
                          "fp4, fp8 or fp16.")
 | 
			
		||||
                          "fp4, fp8, fp16 or mixed_4bit.")
 | 
			
		||||
        qtype = ggml_tensor_qtype[q_k]
 | 
			
		||||
 | 
			
		||||
        # In case it needs a second try,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue