First token lm_head optimization (#10318)

* add lm head linear

* update

* address comments and fix style

* address comment
This commit is contained in:
Yina Chen 2024-03-13 10:11:32 +08:00 committed by GitHub
parent 7cf01e6ec8
commit f5d65203c0
2 changed files with 37 additions and 3 deletions

View file

@ -211,6 +211,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
not isinstance(module, LowBitLinear)): not isinstance(module, LowBitLinear)):
in_features, out_features, mp_group = linear_args in_features, out_features, mp_group = linear_args
optimize_lm_head = False
if name == "lm_head":
if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
None) == "1":
optimize_lm_head = True
with init_empty_weights(): with init_empty_weights():
new_linear = None new_linear = None
is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld) is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
@ -225,6 +230,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
bias=has_bias, bias=has_bias,
mp_group=mp_group, mp_group=mp_group,
enable_xetla=enable_xetla, enable_xetla=enable_xetla,
optimize_lm_head=optimize_lm_head
) )
device = module.qweight.data.device device = module.qweight.data.device
invalidInputError(device.type != "meta", invalidInputError(device.type != "meta",
@ -253,6 +259,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
module.bias is not None, module.bias is not None,
mp_group=mp_group, mp_group=mp_group,
enable_xetla=enable_xetla, enable_xetla=enable_xetla,
optimize_lm_head=optimize_lm_head
) )
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype, cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
full_module_name, full_module_name,
@ -280,6 +287,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
out_features, out_features,
module.bias is not None, module.bias is not None,
mp_group=mp_group, mp_group=mp_group,
optimize_lm_head=optimize_lm_head
) )
device = module.weight.data.device device = module.weight.data.device
from bigdl.llm.transformers.utils import get_ipex_version from bigdl.llm.transformers.utils import get_ipex_version
@ -301,6 +309,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
out_features, out_features,
module.bias is not None, module.bias is not None,
mp_group=mp_group, mp_group=mp_group,
optimize_lm_head=optimize_lm_head
) )
device = module.weight.data.device device = module.weight.data.device
# convert here # convert here

View file

@ -246,6 +246,16 @@ def ggml_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int, qtype:
return dst_tensor return dst_tensor
def reshape_lm_head_input(x):
if x.dim() > 3:
x = x.reshape([-1, x.shape[-2], x.shape[-1]])
shape = list(x.size())
if shape[1] > 10:
shape[1] = 1
x = x[:, -1, :].view(shape)
return x
# Rename to FP4Params to trigger initializing # Rename to FP4Params to trigger initializing
# the params layer with all parameters on the CPU # the params layer with all parameters on the CPU
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333 # https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
@ -505,7 +515,8 @@ class MatMulLowBitCPU(torch.autograd.Function):
class LowBitLinear(nn.Linear): class LowBitLinear(nn.Linear):
def __init__(self, input_features, output_features, qtype, bias=True, def __init__(self, input_features, output_features, qtype, bias=True,
conver_to_half=True, mp_group=None, enable_xetla=False): conver_to_half=True, mp_group=None, enable_xetla=False,
optimize_lm_head=False):
super().__init__(input_features, output_features, bias) super().__init__(input_features, output_features, bias)
self.weight = FP4Params(self.weight.data, self.weight = FP4Params(self.weight.data,
requires_grad=False, requires_grad=False,
@ -520,6 +531,7 @@ class LowBitLinear(nn.Linear):
self.mp_group = mp_group self.mp_group = mp_group
self.compute_dtype = None # only for training self.compute_dtype = None # only for training
self.enable_xetla = enable_xetla self.enable_xetla = enable_xetla
self.optimize_lm_head = optimize_lm_head
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
# Due to inconsistent training status in some models like Baichuan-7b-Chat, # Due to inconsistent training status in some models like Baichuan-7b-Chat,
@ -536,6 +548,9 @@ class LowBitLinear(nn.Linear):
if self.bias is not None and self.bias.dtype != x.dtype: if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype) self.bias.data = self.bias.data.to(x.dtype)
if self.optimize_lm_head:
x = reshape_lm_head_input(x)
# [batch, input_num, in_len] # [batch, input_num, in_len]
# input_num == token num for Transformer # input_num == token num for Transformer
x_shape = x.shape x_shape = x.shape
@ -632,7 +647,8 @@ class LowBitLinear(nn.Linear):
class FP16Linear(nn.Linear): class FP16Linear(nn.Linear):
def __init__(self, input_features, output_features, bias=True, def __init__(self, input_features, output_features, bias=True,
mp_group=None, weight_type=1): mp_group=None, weight_type=1,
optimize_lm_head=False):
super().__init__(input_features, output_features, bias) super().__init__(input_features, output_features, bias)
self.in_len = input_features self.in_len = input_features
self.out_len = output_features self.out_len = output_features
@ -644,11 +660,15 @@ class FP16Linear(nn.Linear):
# weigh_type = 2 means weight has been transposed # weigh_type = 2 means weight has been transposed
# weigh_type = 3 means weight has been transposed by esimd method # weigh_type = 3 means weight has been transposed by esimd method
self.weight_type = 1 self.weight_type = 1
self.optimize_lm_head = optimize_lm_head
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
# only work for GPU # only work for GPU
invalidInputError(x.device.type == "xpu", invalidInputError(x.device.type == "xpu",
"FP16Linear only works for Intel GPUs") "FP16Linear only works for Intel GPUs")
if self.optimize_lm_head:
x = reshape_lm_head_input(x)
x = x.to(torch.float16) x = x.to(torch.float16)
if self.bias is not None and self.bias.dtype != x.dtype: if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype) self.bias.data = self.bias.data.to(x.dtype)
@ -743,7 +763,8 @@ class FP16Linear(nn.Linear):
class BF16Linear(nn.Linear): class BF16Linear(nn.Linear):
def __init__(self, input_features, output_features, bias=True, def __init__(self, input_features, output_features, bias=True,
mp_group=None, compute_dtype=None): mp_group=None, compute_dtype=None,
optimize_lm_head=False):
super().__init__(input_features, output_features, bias) super().__init__(input_features, output_features, bias)
self.in_len = input_features self.in_len = input_features
self.out_len = output_features self.out_len = output_features
@ -752,8 +773,12 @@ class BF16Linear(nn.Linear):
self.qtype = ggml_tensor_qtype["bf16"] self.qtype = ggml_tensor_qtype["bf16"]
self.mp_group = mp_group self.mp_group = mp_group
self.compute_dtype = compute_dtype self.compute_dtype = compute_dtype
self.optimize_lm_head = optimize_lm_head
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
if self.optimize_lm_head:
x = reshape_lm_head_input(x)
x = x.to(torch.bfloat16) x = x.to(torch.bfloat16)
if self.weight is not None and self.weight.dtype != x.dtype: if self.weight is not None and self.weight.dtype != x.dtype:
self.weight.data = self.weight.data.to(x.dtype) self.weight.data = self.weight.data.to(x.dtype)