LLM: basic api support for esimd fp16 (#9067)
* basic api support for fp16
* fix style
* fix
* fix error and style
* fix style
* meet code review
* update based on comments
parent 65373d2a8b
commit f64257a093
4 changed files with 103 additions and 28 deletions
@@ -31,7 +31,8 @@ ggml_tensor_qtype = {"sym_int4": 2,   # q4_0 in ggml
                     "asym_int5": 7,  # q5_1 in ggml
                     "sym_int8": 8,   # q8_0 in ggml
                     "nf4": 10,
-                    "nf3": 11}
+                    "nf3": 11,
+                    "fp16": 12}
 
 _llama_quantize_type = {"q4_0": 2,
                         "q4_1": 3,
@@ -71,7 +72,7 @@ def quantize(input_path: str, output_path: str,
     :param dtype: Quantization method which differs in the resulting model disk size and
             inference speed. Defalut to `q4_0`. Difference model family may support
             different types, now the supported list is:
-            llama : "q4_0", "q4_1", "q4_2"
+            llama : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
             bloom : "q4_0", "q4_1"
             gptneox : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
             starcoder : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
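The new entry makes "fp16" resolvable like any other low-bit type string. A minimal sketch (not part of the patch) of the lookup that the transformers-side code performs against this mapping:

    from bigdl.llm.ggml.quantize import ggml_tensor_qtype

    qtype = ggml_tensor_qtype["fp16"]
    print(qtype)  # 12 with this patch applied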
@@ -41,12 +41,13 @@ from accelerate import init_empty_weights
 import warnings
 import transformers
 import importlib
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from .utils import logger
 
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False):
-    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params
+    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
     has_been_replaced = False
 
     for name, module in model.named_children():
@@ -57,33 +58,55 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             # Check if the current key is not in the `modules_to_not_convert`
             if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                 with init_empty_weights():
-                    new_linear = LowBitLinear(
-                        module.in_features,
-                        module.out_features,
-                        qtype,
-                        module.bias is not None,
-                    )
+                    new_linear = None
+                    if qtype != ggml_tensor_qtype["fp16"]:
+                        new_linear = LowBitLinear(
+                            module.in_features,
+                            module.out_features,
+                            qtype,
+                            module.bias is not None,
+                        )
 
-                    device_type = module.weight.data.device.type
-                    # Copy the weights
-                    paramsLowBit = FP4Params(data=module.weight.data,
-                                             requires_grad=False,
-                                             quantized=False,
-                                             _shape=None,
-                                             convert_shape_only=convert_shape_only,
-                                             qtype=qtype).to(device_type)
-                    new_linear._parameters['weight'] = paramsLowBit
+                        device_type = module.weight.data.device.type
+                        # Copy the weights
+                        paramsLowBit = FP4Params(data=module.weight.data,
+                                                 requires_grad=False,
+                                                 quantized=False,
+                                                 _shape=None,
+                                                 convert_shape_only=convert_shape_only,
+                                                 qtype=qtype).to(device_type)
+                        new_linear._parameters['weight'] = paramsLowBit
+                    else:
+                        #  only support two size now
+                        #  may generalize to other sizes
+                        if module.in_features in [4096, 11008]:
+                            # esimd fp16 path
+                            new_linear = FP16Linear(
+                                module.in_features,
+                                module.out_features,
+                                qtype,
+                                module.bias is not None,
+                            )
+                            device_type = module.weight.data.device.type
 
-                    if module.bias is not None:
-                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                            .to(device_type)
+                            # convert here
+                            m, n = module.weight.data.shape
+                            trans_weight = module.weight.data.reshape(m//16, 16, n)
+                            trans_weight = trans_weight.transpose(1, 2).contiguous()
+                            new_linear._parameters['weight'] = nn.Parameter(trans_weight)
 
-                    model._modules[name] = new_linear
-                    has_been_replaced = True
-                    # Force requires grad to False to avoid unexpected errors
-                    model._modules[name].requires_grad_(False)
+                    #  fp16 may generalize to other sizes later
+                    if new_linear is not None:
+                        if module.bias is not None:
+                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                                .to(device_type)
 
-                    module.weight = None
+                        model._modules[name] = new_linear
+                        has_been_replaced = True
+                        # Force requires grad to False to avoid unexpected errors
+                        model._modules[name].requires_grad_(False)
 
+                        module.weight = None
 
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
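For the esimd path, each eligible weight is repacked from its usual (out_features, in_features) layout into blocks of 16 output rows, with each block's columns made contiguous. Below is a small standalone sketch (illustration only, plain PyTorch on CPU; the 4096x4096 size is just one of the shapes the patch accepts) of that packing and of the inverse reshape that FP16Linear's batched path uses to recover the original layout:

    import torch

    m, n = 4096, 4096                          # (out_features, in_features)
    weight = torch.randn(m, n).half()          # stand-in for module.weight.data

    # packing done in _replace_with_low_bit_linear for the esimd fp16 path
    trans_weight = weight.reshape(m // 16, 16, n).transpose(1, 2).contiguous()
    print(trans_weight.shape)                  # torch.Size([256, 4096, 16])

    # inverse transform used by FP16Linear.forward when x_2d.shape[0] > 1
    restored = trans_weight.transpose(1, 2).reshape(m, n)
    assert torch.equal(restored, weight)       # round-trips exactly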
@@ -378,3 +378,53 @@ class LowBitLinear(nn.Linear):
                     result += self.bias
 
         return result.to(x.dtype)
+
+
+class FP16Linear(nn.Linear):
+    def __init__(self, input_features, output_features, qtype, bias=True,
+                 conver_to_half=True):
+        super().__init__(input_features, output_features, bias)
+        self.in_len = input_features
+        self.out_len = output_features
+        self.weight_shape = (self.out_len, self.in_len)
+        self.weight_length = self.out_len * self.in_len
+        self.qtype = qtype
+        self.conver_to_half = conver_to_half
+
+    def forward(self, x: torch.Tensor):
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        x_shape = x.shape
+        x_2d = x.view(-1, x_shape[-1])
+
+        x0 = self.weight.data
+        # only work for GPU
+        invalidInputError(x0.device.type == "xpu",
+                          "FP16 only works for GPU")
+        try:
+            import intel_extension_for_pytorch
+            import linear_fp16_esimd
+        except ModuleNotFoundError:
+            invalidInputError(False,
+                              "Please `pip install bigdl_core_xe` first.")
+
+        if x_2d.is_contiguous() is False:
+            x_2d = x_2d.contiguous()
+
+        if x_2d.shape[0] > 1:
+            # first token or batch size > 1, re-convert weight
+            original_weight = self.weight.data.transpose(1, 2)
+            original_weight = original_weight.reshape(self.out_len, self.in_len)
+            result = F.linear(x_2d, original_weight.contiguous())
+            del original_weight
+        else:
+            # rest token, use esimd optimization
+            result = linear_fp16_esimd.forward(x_2d, self.weight.data)
+
+        new_shape = x_shape[:-1] + (self.out_len,)
+        result = result.view(new_shape)
+        if self.bias is not None:
+            result += self.bias
+
+        return result.to(x.dtype)
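To show how the pieces fit together, here is an illustrative sketch (not from the patch) that wraps a single nn.Linear in FP16Linear by hand, mirroring the convert path above. It assumes an Intel GPU exposed as "xpu" with intel_extension_for_pytorch and bigdl_core_xe installed; without them FP16Linear.forward raises via invalidInputError.

    import torch
    import torch.nn as nn
    import intel_extension_for_pytorch  # noqa: F401  (assumed installed; registers the "xpu" device)
    from bigdl.llm.ggml.quantize import ggml_tensor_qtype
    from bigdl.llm.transformers.low_bit_linear import FP16Linear

    linear = nn.Linear(4096, 4096, bias=False).half()   # toy stand-in for an attention projection

    fp16_linear = FP16Linear(linear.in_features, linear.out_features,
                             ggml_tensor_qtype["fp16"], linear.bias is not None)
    # pack the weight exactly as _replace_with_low_bit_linear does
    m, n = linear.weight.data.shape
    trans_weight = linear.weight.data.reshape(m // 16, 16, n).transpose(1, 2).contiguous()
    fp16_linear._parameters['weight'] = nn.Parameter(trans_weight, requires_grad=False)
    fp16_linear = fp16_linear.to("xpu")

    x = torch.randn(1, 1, 4096, dtype=torch.float16, device="xpu")
    out = fp16_linear(x)   # single-token shape -> esimd kernel; larger batches fall back to F.linear
    print(out.shape)       # torch.Size([1, 1, 4096])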
@@ -60,7 +60,7 @@ class _BaseAutoModelClass:
         :param load_in_4bit: boolean value, True means load linear's weight to symmetric int 4.
                              Default to be False.
         :param load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5
-                                or sym_int8. sym_int4 means symmetric int 4, asym_int4 means
+                                , sym_int8 or fp16. sym_int4 means symmetric int 4, asym_int4 means
                                 asymmetric int 4, etc. Relevant low bit optimizations will
                                 be applied to the model.
         :param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
@@ -104,8 +104,9 @@ class _BaseAutoModelClass:
         from .convert import ggml_convert_low_bit
         invalidInputError(q_k in ggml_tensor_qtype,
                           f"Unknown load_in_low_bit value: {q_k}, expected:"
-                          f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
+                          f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8 or fp16.")
         qtype = ggml_tensor_qtype[q_k]
 
         # In case it needs a second try,
         # `from_pretrained`` may pop items out in dict
         # and lead to args missing.
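End to end, nothing changes in the user-facing API beyond the new load_in_low_bit value. A minimal usage sketch (assumptions: a local Llama-2 7B checkpoint at a placeholder path, an Intel GPU environment with intel_extension_for_pytorch installed, and the caller handling the half()/to("xpu") step):

    import torch
    import intel_extension_for_pytorch  # noqa: F401  (assumed installed; registers the "xpu" device)
    from transformers import LlamaTokenizer
    from bigdl.llm.transformers import AutoModelForCausalLM

    model_path = "/path/to/llama-2-7b-chat-hf"   # placeholder checkpoint path

    # "fp16" now routes eligible Linear layers to FP16Linear instead of LowBitLinear
    model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit="fp16")
    model = model.half().to("xpu")               # assumption: caller casts and moves the model

    tokenizer = LlamaTokenizer.from_pretrained(model_path)
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu")
    with torch.inference_mode():
        output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))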