First token lm_head optimization (#10318)
* add lm head linear * update * address comments and fix style * address comment
This commit is contained in:
parent
7cf01e6ec8
commit
f5d65203c0
2 changed files with 37 additions and 3 deletions
|
|
@ -211,6 +211,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
|
||||
not isinstance(module, LowBitLinear)):
|
||||
in_features, out_features, mp_group = linear_args
|
||||
optimize_lm_head = False
|
||||
if name == "lm_head":
|
||||
if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
|
||||
None) == "1":
|
||||
optimize_lm_head = True
|
||||
with init_empty_weights():
|
||||
new_linear = None
|
||||
is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
|
||||
|
|
@ -225,6 +230,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
bias=has_bias,
|
||||
mp_group=mp_group,
|
||||
enable_xetla=enable_xetla,
|
||||
optimize_lm_head=optimize_lm_head
|
||||
)
|
||||
device = module.qweight.data.device
|
||||
invalidInputError(device.type != "meta",
|
||||
|
|
@ -253,6 +259,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
enable_xetla=enable_xetla,
|
||||
optimize_lm_head=optimize_lm_head
|
||||
)
|
||||
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
|
||||
full_module_name,
|
||||
|
|
@ -280,6 +287,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
out_features,
|
||||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
optimize_lm_head=optimize_lm_head
|
||||
)
|
||||
device = module.weight.data.device
|
||||
from bigdl.llm.transformers.utils import get_ipex_version
|
||||
|
|
@ -301,6 +309,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
out_features,
|
||||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
optimize_lm_head=optimize_lm_head
|
||||
)
|
||||
device = module.weight.data.device
|
||||
# convert here
|
||||
|
|
|
|||
|
|
@ -246,6 +246,16 @@ def ggml_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int, qtype:
|
|||
return dst_tensor
|
||||
|
||||
|
||||
def reshape_lm_head_input(x):
|
||||
if x.dim() > 3:
|
||||
x = x.reshape([-1, x.shape[-2], x.shape[-1]])
|
||||
shape = list(x.size())
|
||||
if shape[1] > 10:
|
||||
shape[1] = 1
|
||||
x = x[:, -1, :].view(shape)
|
||||
return x
|
||||
|
||||
|
||||
# Rename to FP4Params to trigger initializing
|
||||
# the params layer with all parameters on the CPU
|
||||
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
|
||||
|
|
@ -505,7 +515,8 @@ class MatMulLowBitCPU(torch.autograd.Function):
|
|||
|
||||
class LowBitLinear(nn.Linear):
|
||||
def __init__(self, input_features, output_features, qtype, bias=True,
|
||||
conver_to_half=True, mp_group=None, enable_xetla=False):
|
||||
conver_to_half=True, mp_group=None, enable_xetla=False,
|
||||
optimize_lm_head=False):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.weight = FP4Params(self.weight.data,
|
||||
requires_grad=False,
|
||||
|
|
@ -520,6 +531,7 @@ class LowBitLinear(nn.Linear):
|
|||
self.mp_group = mp_group
|
||||
self.compute_dtype = None # only for training
|
||||
self.enable_xetla = enable_xetla
|
||||
self.optimize_lm_head = optimize_lm_head
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# Due to inconsistent training status in some models like Baichuan-7b-Chat,
|
||||
|
|
@ -536,6 +548,9 @@ class LowBitLinear(nn.Linear):
|
|||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
if self.optimize_lm_head:
|
||||
x = reshape_lm_head_input(x)
|
||||
|
||||
# [batch, input_num, in_len]
|
||||
# input_num == token num for Transformer
|
||||
x_shape = x.shape
|
||||
|
|
@ -632,7 +647,8 @@ class LowBitLinear(nn.Linear):
|
|||
|
||||
class FP16Linear(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True,
|
||||
mp_group=None, weight_type=1):
|
||||
mp_group=None, weight_type=1,
|
||||
optimize_lm_head=False):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.in_len = input_features
|
||||
self.out_len = output_features
|
||||
|
|
@ -644,11 +660,15 @@ class FP16Linear(nn.Linear):
|
|||
# weigh_type = 2 means weight has been transposed
|
||||
# weigh_type = 3 means weight has been transposed by esimd method
|
||||
self.weight_type = 1
|
||||
self.optimize_lm_head = optimize_lm_head
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# only work for GPU
|
||||
invalidInputError(x.device.type == "xpu",
|
||||
"FP16Linear only works for Intel GPUs")
|
||||
if self.optimize_lm_head:
|
||||
x = reshape_lm_head_input(x)
|
||||
|
||||
x = x.to(torch.float16)
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
|
@ -743,7 +763,8 @@ class FP16Linear(nn.Linear):
|
|||
|
||||
class BF16Linear(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True,
|
||||
mp_group=None, compute_dtype=None):
|
||||
mp_group=None, compute_dtype=None,
|
||||
optimize_lm_head=False):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.in_len = input_features
|
||||
self.out_len = output_features
|
||||
|
|
@ -752,8 +773,12 @@ class BF16Linear(nn.Linear):
|
|||
self.qtype = ggml_tensor_qtype["bf16"]
|
||||
self.mp_group = mp_group
|
||||
self.compute_dtype = compute_dtype
|
||||
self.optimize_lm_head = optimize_lm_head
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.optimize_lm_head:
|
||||
x = reshape_lm_head_input(x)
|
||||
|
||||
x = x.to(torch.bfloat16)
|
||||
if self.weight is not None and self.weight.dtype != x.dtype:
|
||||
self.weight.data = self.weight.data.to(x.dtype)
|
||||
|
|
|
|||
Loading…
Reference in a new issue