First token lm_head optimization (#10318)

* add lm head linear

* update

* address comments and fix style

* address comment
This commit is contained in:
Yina Chen 2024-03-13 10:11:32 +08:00 committed by GitHub
parent 7cf01e6ec8
commit f5d65203c0
2 changed files with 37 additions and 3 deletions

View file

@ -211,6 +211,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
not isinstance(module, LowBitLinear)):
in_features, out_features, mp_group = linear_args
optimize_lm_head = False
if name == "lm_head":
if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
None) == "1":
optimize_lm_head = True
with init_empty_weights():
new_linear = None
is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
@ -225,6 +230,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
bias=has_bias,
mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=optimize_lm_head
)
device = module.qweight.data.device
invalidInputError(device.type != "meta",
@ -253,6 +259,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
module.bias is not None,
mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=optimize_lm_head
)
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
full_module_name,
@ -280,6 +287,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
out_features,
module.bias is not None,
mp_group=mp_group,
optimize_lm_head=optimize_lm_head
)
device = module.weight.data.device
from bigdl.llm.transformers.utils import get_ipex_version
@ -301,6 +309,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
out_features,
module.bias is not None,
mp_group=mp_group,
optimize_lm_head=optimize_lm_head
)
device = module.weight.data.device
# convert here

View file

@ -246,6 +246,16 @@ def ggml_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int, qtype:
return dst_tensor
def reshape_lm_head_input(x):
if x.dim() > 3:
x = x.reshape([-1, x.shape[-2], x.shape[-1]])
shape = list(x.size())
if shape[1] > 10:
shape[1] = 1
x = x[:, -1, :].view(shape)
return x
# Rename to FP4Params to trigger initializing
# the params layer with all parameters on the CPU
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
@ -505,7 +515,8 @@ class MatMulLowBitCPU(torch.autograd.Function):
class LowBitLinear(nn.Linear):
def __init__(self, input_features, output_features, qtype, bias=True,
conver_to_half=True, mp_group=None, enable_xetla=False):
conver_to_half=True, mp_group=None, enable_xetla=False,
optimize_lm_head=False):
super().__init__(input_features, output_features, bias)
self.weight = FP4Params(self.weight.data,
requires_grad=False,
@ -520,6 +531,7 @@ class LowBitLinear(nn.Linear):
self.mp_group = mp_group
self.compute_dtype = None # only for training
self.enable_xetla = enable_xetla
self.optimize_lm_head = optimize_lm_head
def forward(self, x: torch.Tensor):
# Due to inconsistent training status in some models like Baichuan-7b-Chat,
@ -536,6 +548,9 @@ class LowBitLinear(nn.Linear):
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
if self.optimize_lm_head:
x = reshape_lm_head_input(x)
# [batch, input_num, in_len]
# input_num == token num for Transformer
x_shape = x.shape
@ -632,7 +647,8 @@ class LowBitLinear(nn.Linear):
class FP16Linear(nn.Linear):
def __init__(self, input_features, output_features, bias=True,
mp_group=None, weight_type=1):
mp_group=None, weight_type=1,
optimize_lm_head=False):
super().__init__(input_features, output_features, bias)
self.in_len = input_features
self.out_len = output_features
@ -644,11 +660,15 @@ class FP16Linear(nn.Linear):
# weigh_type = 2 means weight has been transposed
# weigh_type = 3 means weight has been transposed by esimd method
self.weight_type = 1
self.optimize_lm_head = optimize_lm_head
def forward(self, x: torch.Tensor):
# only work for GPU
invalidInputError(x.device.type == "xpu",
"FP16Linear only works for Intel GPUs")
if self.optimize_lm_head:
x = reshape_lm_head_input(x)
x = x.to(torch.float16)
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
@ -743,7 +763,8 @@ class FP16Linear(nn.Linear):
class BF16Linear(nn.Linear):
def __init__(self, input_features, output_features, bias=True,
mp_group=None, compute_dtype=None):
mp_group=None, compute_dtype=None,
optimize_lm_head=False):
super().__init__(input_features, output_features, bias)
self.in_len = input_features
self.out_len = output_features
@ -752,8 +773,12 @@ class BF16Linear(nn.Linear):
self.qtype = ggml_tensor_qtype["bf16"]
self.mp_group = mp_group
self.compute_dtype = compute_dtype
self.optimize_lm_head = optimize_lm_head
def forward(self, x: torch.Tensor):
if self.optimize_lm_head:
x = reshape_lm_head_input(x)
x = x.to(torch.bfloat16)
if self.weight is not None and self.weight.dtype != x.dtype:
self.weight.data = self.weight.data.to(x.dtype)