[LLM] Support mixed_fp8 on Arc (#9415)
* ut gpu allocation memory fix
* support mixed_8bit on arc
* rename mixed_4bit to mixed_fp4 and mixed_8bit to mixed_fp8
* revert unexpected changes
* unify common logic
* rename in llm xmx_checker
* fix typo and re-unify
parent ac7fbe77e2
commit 2888818b3a

4 changed files with 36 additions and 17 deletions
@@ -35,7 +35,8 @@ ggml_tensor_qtype = {"sym_int4": 2,   # q4_0 in ggml
                     "fp16": 12,
                     "fp8": 15,
                     "fp4": 16,
-                    "mixed_4bit": 17}     # Mixture of Formats Quantization 4 bits
+                    "mixed_fp4": 17,     # Mixture of Formats Quantization 4 bits
+                    "mixed_fp8": 18}     # Mixture of Formats Quantization 8 bits
 
 _llama_quantize_type = {"q4_0": 2,
                        "q4_1": 3,
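For reference, the new entry is consumed through a plain dict lookup; a minimal sketch, assuming the registry lives at bigdl.llm.ggml.quantize as elsewhere in this repo:

from bigdl.llm.ggml.quantize import ggml_tensor_qtype  # module path assumed

# "mixed_fp8" now resolves to qtype id 18; the renamed "mixed_fp4" keeps 17.
qtype = ggml_tensor_qtype["mixed_fp8"]
assert qtype == 18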
@@ -66,7 +66,8 @@ NF4 = ggml_tensor_qtype["nf4"]
 NF3 = ggml_tensor_qtype["nf3"]
 FP8 = ggml_tensor_qtype["fp8"]
 FP4 = ggml_tensor_qtype["fp4"]
-MOFQ4 = ggml_tensor_qtype["mixed_4bit"]
+MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
+MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
 
 
 def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
@@ -194,6 +195,16 @@ class FP4Params(torch.nn.Parameter):
         self.convert_shape_only = convert_shape_only
         return self
 
+    def ggml_mse(self, w, ggml_qtype, device):
+        from torch.nn.functional import mse_loss
+        w_quant = ggml_convert_qtype(w, ggml_qtype,
+                                     device=device,
+                                     convert_shape_only=self.convert_shape_only)
+        w_dequant = ggml_convert_fp32(w_quant, w.shape,
+                                      reduce(mul, w.shape, 1), ggml_qtype)
+        mse = mse_loss(w_dequant, w)
+        return mse, w_quant
+
     def quantize(self, device=None):
         if not self.quantized:
             w = self.data.contiguous().float()
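The new ggml_mse helper deduplicates the pattern the MOFQ4 branch previously spelled out inline: quantize a weight, dequantize it back to fp32, and score the round trip with mse_loss. A self-contained sketch of the same idea, with a hypothetical symmetric int8 round trip standing in for the real ggml_convert_qtype/ggml_convert_fp32 pair:

import torch
from torch.nn.functional import mse_loss

def roundtrip_mse(w: torch.Tensor) -> torch.Tensor:
    # Stand-in for ggml_convert_qtype: symmetric per-tensor int8 quantization.
    scale = w.abs().max() / 127.0
    w_quant = torch.clamp(torch.round(w / scale), -127, 127)
    # Stand-in for ggml_convert_fp32: dequantize back to the original dtype.
    w_dequant = w_quant * scale
    return mse_loss(w_dequant, w)

w = torch.randn(256, 256)
print(float(roundtrip_mse(w)))  # the error each candidate format is ranked by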
@@ -206,25 +217,31 @@ class FP4Params(torch.nn.Parameter):
                     # save/load
                     self.qtype = SYM_INT4
                 else:
-                    from torch.nn.functional import mse_loss
-                    w_quant_q4_0 = ggml_convert_qtype(w, SYM_INT4,
-                                                      device=device,
-                                                      convert_shape_only=self.convert_shape_only)
-                    w_q4_0_dequant = ggml_convert_fp32(w_quant_q4_0, w.shape,
-                                                       reduce(mul, w.shape, 1), SYM_INT4)
-                    w_quant_fp4 = ggml_convert_qtype(w, FP4,
-                                                     device=device,
-                                                     convert_shape_only=self.convert_shape_only)
-                    w_fp4_dequant = ggml_convert_fp32(w_quant_fp4, w.shape,
-                                                      reduce(mul, w.shape, 1), FP4)
-                    q4_0_mse = mse_loss(w_q4_0_dequant, w)
-                    fp4_mse = mse_loss(w_fp4_dequant, w)
+                    q4_0_mse, w_quant_q4_0 = self.ggml_mse(w, SYM_INT4, device=device)
+                    fp4_mse, w_quant_fp4 = self.ggml_mse(w, FP4, device=device)
                     if q4_0_mse <= fp4_mse:
                         self.qtype = SYM_INT4
                         self.data = w_quant_q4_0
                     else:
                         self.qtype = FP4
                         self.data = w_quant_fp4
+            elif self.qtype == MOFQ8:
+                if device == 'meta':
+                    w_quantized = ggml_convert_qtype(w, SYM_INT8,
+                                                     device=device,
+                                                     convert_shape_only=self.convert_shape_only)
+                    # TODO: should load from config, the current implementation doesn't support
+                    # save/load
+                    self.qtype = SYM_INT8
+                else:
+                    q8_0_mse, w_quant_q8_0 = self.ggml_mse(w, SYM_INT8, device=device)
+                    fp8_mse, w_quant_fp8 = self.ggml_mse(w, FP8, device=device)
+                    if q8_0_mse <= fp8_mse:
+                        self.qtype = SYM_INT8
+                        self.data = w_quant_q8_0
+                    else:
+                        self.qtype = FP8
+                        self.data = w_quant_fp8
             else:
                 w_quantized = ggml_convert_qtype(w, self.qtype,
                                                  device=device,
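The new MOFQ8 branch mirrors the existing MOFQ4 path one tier up: on a real device it computes both 8-bit round-trip errors via ggml_mse and keeps whichever format reconstructs the weight better, with <= breaking ties toward q8_0 (on 'meta' it falls back to SYM_INT8 unconditionally, since there is no weight data to compare). A minimal self-contained sketch of that decision, using torch.float8_e4m3fn (PyTorch >= 2.1) as a stand-in for ggml's fp8:

import torch
from torch.nn.functional import mse_loss

def int8_mse(w):
    # Symmetric per-tensor int8 round trip; stand-in for ggml's q8_0.
    scale = w.abs().max() / 127.0
    return mse_loss(torch.clamp(torch.round(w / scale), -127, 127) * scale, w)

def fp8_mse(w):
    # float8_e4m3 cast round trip; stand-in for ggml's fp8.
    return mse_loss(w.to(torch.float8_e4m3fn).to(w.dtype), w)

w = torch.randn(1024, 1024)
qtype = "sym_int8" if int8_mse(w) <= fp8_mse(w) else "fp8"  # ties go to q8_0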
@@ -113,7 +113,7 @@ class _BaseAutoModelClass:
         invalidInputError(q_k in ggml_tensor_qtype,
                           f"Unknown load_in_low_bit value: {q_k}, expected:"
                           f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
-                          "fp4, fp8, fp16 or mixed_4bit.")
+                          "fp4, fp8, fp16, mixed_fp4 or mixed_fp8.")
         qtype = ggml_tensor_qtype[q_k]
 
         # In case it needs a second try,
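With the updated validation, mixed_fp8 is requested the same way as the other low-bit formats. A hedged usage sketch: the bigdl.llm import path and the load_in_low_bit keyword match this repo, while the model id is a placeholder and the final .to("xpu") assumes an IPEX/Arc setup:

from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",        # placeholder model id
    load_in_low_bit="mixed_fp8",       # per-tensor choice between q8_0 and fp8
    trust_remote_code=True,
)
model = model.to("xpu")                # Arc GPUs surface as "xpu" through IPEX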
@@ -25,7 +25,8 @@ NF4 = ggml_tensor_qtype["nf4"]
 NF3 = ggml_tensor_qtype["nf3"]
 FP8 = ggml_tensor_qtype["fp8"]
 FP4 = ggml_tensor_qtype["fp4"]
-MOFQ4 = ggml_tensor_qtype["mixed_4bit"]
+MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
+MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
 
 
 class XMXChecker: