First token lm_head optimization (#10318)
* add lm head linear
* update
* address comments and fix style
* address comment
This commit is contained in:

parent 7cf01e6ec8
commit f5d65203c0

2 changed files with 37 additions and 3 deletions
@@ -211,6 +211,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
                     not isinstance(module, LowBitLinear)):
                 in_features, out_features, mp_group = linear_args
+                optimize_lm_head = False
+                if name == "lm_head":
+                    if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
+                                                                          None) == "1":
+                        optimize_lm_head = True
                 with init_empty_weights():
                     new_linear = None
                     is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
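Note on the gate above: the new path is opt-in and model-specific. It only triggers when the module is named "lm_head", the model type is gptj or llama, and BIGDL_OPTIMIZE_LM_HEAD=1 is set before conversion; every other linear keeps optimize_lm_head=False. A minimal usage sketch, assuming the usual bigdl-llm low-bit loading path (the checkpoint id is only an example):

import os

# Must be set before from_pretrained(), since the check runs while the
# checkpoint's Linear modules are being replaced with low-bit ones.
os.environ["BIGDL_OPTIMIZE_LM_HEAD"] = "1"

from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_4bit=True)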
@@ -225,6 +230,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             bias=has_bias,
                             mp_group=mp_group,
                             enable_xetla=enable_xetla,
+                            optimize_lm_head=optimize_lm_head
                         )
                         device = module.qweight.data.device
                         invalidInputError(device.type != "meta",
@@ -253,6 +259,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             module.bias is not None,
                             mp_group=mp_group,
                             enable_xetla=enable_xetla,
+                            optimize_lm_head=optimize_lm_head
                         )
                         cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
                                                                            full_module_name,
@@ -280,6 +287,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             out_features,
                             module.bias is not None,
                             mp_group=mp_group,
+                            optimize_lm_head=optimize_lm_head
                         )
                         device = module.weight.data.device
                         from bigdl.llm.transformers.utils import get_ipex_version
@@ -301,6 +309,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             out_features,
                             module.bias is not None,
                             mp_group=mp_group,
+                            optimize_lm_head=optimize_lm_head
                         )
                         device = module.weight.data.device
                         # convert here
@@ -246,6 +246,16 @@ def ggml_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int, qtype:
     return dst_tensor
 
 
+def reshape_lm_head_input(x):
+    if x.dim() > 3:
+        x = x.reshape([-1, x.shape[-2], x.shape[-1]])
+    shape = list(x.size())
+    if shape[1] > 10:
+        shape[1] = 1
+        x = x[:, -1, :].view(shape)
+    return x
+
+
 # Rename to FP4Params to trigger initializing
 # the params layer with all parameters on the CPU
 # https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
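reshape_lm_head_input is the core of the change: during prefill the lm_head input carries every prompt position, but only the last position is needed to sample the first generated token, so trimming before the vocabulary-sized matmul saves both compute and the [batch, seq_len, vocab] logits buffer. During decode (a sequence dimension of 10 or less) it is a no-op. A self-contained sketch of the effect, with a plain nn.Linear standing in for the low-bit lm_head and toy sizes:

import torch
import torch.nn as nn

def reshape_lm_head_input(x):
    # same logic as the hunk above: keep only the last position
    # when the sequence dimension looks like a prefill pass
    if x.dim() > 3:
        x = x.reshape([-1, x.shape[-2], x.shape[-1]])
    shape = list(x.size())
    if shape[1] > 10:
        shape[1] = 1
        x = x[:, -1, :].view(shape)
    return x

hidden_size, vocab_size = 64, 1000        # toy sizes; Llama-2-7B is 4096 x 32000
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

prefill = torch.randn(1, 1024, hidden_size)          # 1024 prompt tokens
full = lm_head(prefill)                              # [1, 1024, 1000]
trimmed = lm_head(reshape_lm_head_input(prefill))    # [1, 1, 1000]

# The first generated token only depends on the last position's logits,
# so the trimmed path is equivalent while doing ~1/1024 of the work.
assert torch.allclose(full[:, -1, :], trimmed[:, -1, :])

decode = torch.randn(1, 1, hidden_size)              # single-token decode step
assert reshape_lm_head_input(decode).shape == (1, 1, hidden_size)   # untouched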
@@ -505,7 +515,8 @@ class MatMulLowBitCPU(torch.autograd.Function):
 
 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
-                 conver_to_half=True, mp_group=None, enable_xetla=False):
+                 conver_to_half=True, mp_group=None, enable_xetla=False,
+                 optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
@@ -520,6 +531,7 @@ class LowBitLinear(nn.Linear):
         self.mp_group = mp_group
         self.compute_dtype = None  # only for training
         self.enable_xetla = enable_xetla
+        self.optimize_lm_head = optimize_lm_head
 
     def forward(self, x: torch.Tensor):
         # Due to inconsistent training status in some models like Baichuan-7b-Chat,
@@ -536,6 +548,9 @@ class LowBitLinear(nn.Linear):
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
 
+        if self.optimize_lm_head:
+            x = reshape_lm_head_input(x)
+
         # [batch, input_num, in_len]
         # input_num == token num for Transformer
         x_shape = x.shape
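With the flag set, the trim above happens before x is flattened for the low-bit matmul, so the prefill-time GEMM that produces vocabulary logits shrinks from seq_len rows to a single row. Back-of-the-envelope cost for a Llama-2-7B-sized head (illustrative numbers, not measurements):

# Prefill cost of the lm_head projection, multiply-add counted as 2 FLOPs.
batch, seq_len, hidden, vocab = 1, 1024, 4096, 32000

flops_full = 2 * batch * seq_len * hidden * vocab   # without optimize_lm_head
flops_trim = 2 * batch * 1 * hidden * vocab         # with optimize_lm_head

print(f"{flops_full / 1e9:.1f} GFLOPs -> {flops_trim / 1e9:.2f} GFLOPs "
      f"({flops_full // flops_trim}x less first-token work)")
# 268.4 GFLOPs -> 0.26 GFLOPs (1024x less first-token work)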
@@ -632,7 +647,8 @@ class LowBitLinear(nn.Linear):
 
 class FP16Linear(nn.Linear):
     def __init__(self, input_features, output_features, bias=True,
-                 mp_group=None, weight_type=1):
+                 mp_group=None, weight_type=1,
+                 optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
         self.out_len = output_features
@@ -644,11 +660,15 @@ class FP16Linear(nn.Linear):
         # weigh_type = 2 means weight has been transposed
         # weigh_type = 3 means weight has been transposed by esimd method
         self.weight_type = 1
+        self.optimize_lm_head = optimize_lm_head
 
     def forward(self, x: torch.Tensor):
         # only work for GPU
         invalidInputError(x.device.type == "xpu",
                           "FP16Linear only works for Intel GPUs")
+        if self.optimize_lm_head:
+            x = reshape_lm_head_input(x)
+
         x = x.to(torch.float16)
         if self.bias is not None and self.bias.dtype != x.dtype:
                 self.bias.data = self.bias.data.to(x.dtype)
@@ -743,7 +763,8 @@ class FP16Linear(nn.Linear):
 
 class BF16Linear(nn.Linear):
     def __init__(self, input_features, output_features, bias=True,
-                 mp_group=None, compute_dtype=None):
+                 mp_group=None, compute_dtype=None,
+                 optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
         self.out_len = output_features
@@ -752,8 +773,12 @@ class BF16Linear(nn.Linear):
         self.qtype = ggml_tensor_qtype["bf16"]
         self.mp_group = mp_group
         self.compute_dtype = compute_dtype
+        self.optimize_lm_head = optimize_lm_head
 
     def forward(self, x: torch.Tensor):
+        if self.optimize_lm_head:
+            x = reshape_lm_head_input(x)
+
         x = x.to(torch.bfloat16)
         if self.weight is not None and self.weight.dtype != x.dtype:
             self.weight.data = self.weight.data.to(x.dtype)
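Taken together, the three linear classes receive the same two-part edit: store the flag in __init__ and trim at the top of forward. A condensed view of that pattern as a hypothetical standalone class (not part of this patch), which may help when porting the optimization to another linear variant:

import torch
import torch.nn as nn

class LastTokenOnlyLinear(nn.Linear):
    # Hypothetical condensation of what this commit adds to LowBitLinear,
    # FP16Linear and BF16Linear.
    def __init__(self, in_features, out_features, bias=True, optimize_lm_head=False):
        super().__init__(in_features, out_features, bias)
        self.optimize_lm_head = optimize_lm_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.optimize_lm_head and x.dim() == 3 and x.shape[1] > 10:
            # prefill: only the last position feeds the vocab projection
            x = x[:, -1:, :]
        return super().forward(x)

# The flag is True only for the module named "lm_head", mirroring the gate
# in _replace_with_low_bit_linear above.
head = LastTokenOnlyLinear(64, 1000, bias=False, optimize_lm_head=True)
print(head(torch.randn(1, 1024, 64)).shape)   # torch.Size([1, 1, 1000])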