First token lm_head optimization (#10318)
* add lm head linear
* update
* address comments and fix style
* address comment
parent 7cf01e6ec8
commit f5d65203c0
2 changed files with 37 additions and 3 deletions
@@ -211,6 +211,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
                 not isinstance(module, LowBitLinear)):
             in_features, out_features, mp_group = linear_args
+            optimize_lm_head = False
+            if name == "lm_head":
+                if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
+                                                                      None) == "1":
+                    optimize_lm_head = True
             with init_empty_weights():
                 new_linear = None
                 is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
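
The new path is opt-in: it is only considered for the lm_head module of gptj and llama models, and only when the BIGDL_OPTIMIZE_LM_HEAD environment variable is set to "1" before the model is converted. A minimal usage sketch; only the environment variable name and the supported model types come from this diff, while the from_pretrained call assumes BigDL-LLM's usual low-bit loading API:

    # Hypothetical usage sketch: opt in to the first-token lm_head optimization
    # before loading a llama/gptj model. The loading call below is an assumption
    # based on BigDL-LLM's usual low-bit API, not part of this diff.
    import os
    os.environ["BIGDL_OPTIMIZE_LM_HEAD"] = "1"  # must be set before conversion

    from bigdl.llm.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("path/to/llama-model",
                                                 load_in_4bit=True,
                                                 trust_remote_code=True)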
@@ -225,6 +230,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         bias=has_bias,
                         mp_group=mp_group,
                         enable_xetla=enable_xetla,
+                        optimize_lm_head=optimize_lm_head
                     )
                     device = module.qweight.data.device
                     invalidInputError(device.type != "meta",

@@ -253,6 +259,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         module.bias is not None,
                         mp_group=mp_group,
                         enable_xetla=enable_xetla,
+                        optimize_lm_head=optimize_lm_head
                     )
                     cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
                                                                        full_module_name,

@@ -280,6 +287,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         out_features,
                         module.bias is not None,
                         mp_group=mp_group,
+                        optimize_lm_head=optimize_lm_head
                     )
                     device = module.weight.data.device
                     from bigdl.llm.transformers.utils import get_ipex_version

@@ -301,6 +309,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         out_features,
                         module.bias is not None,
                         mp_group=mp_group,
+                        optimize_lm_head=optimize_lm_head
                     )
                     device = module.weight.data.device
                     # convert here
@@ -246,6 +246,16 @@ def ggml_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int, qtype:
     return dst_tensor


+def reshape_lm_head_input(x):
+    if x.dim() > 3:
+        x = x.reshape([-1, x.shape[-2], x.shape[-1]])
+    shape = list(x.size())
+    if shape[1] > 10:
+        shape[1] = 1
+        x = x[:, -1, :].view(shape)
+    return x
+
+
 # Rename to FP4Params to trigger initializing
 # the params layer with all parameters on the CPU
 # https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
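
reshape_lm_head_input keeps only the last position of the lm_head input whenever the sequence dimension is longer than 10, i.e. during the prefill pass that produces the first token; single-token decode steps pass through unchanged. A standalone sketch of the same logic, with illustrative shapes:

    # Standalone sketch of the helper added above, showing its effect on shapes.
    # The batch size, sequence length and hidden size are illustrative only.
    import torch

    def reshape_lm_head_input(x):
        if x.dim() > 3:
            x = x.reshape([-1, x.shape[-2], x.shape[-1]])
        shape = list(x.size())
        if shape[1] > 10:                     # long sequence => prefill phase
            shape[1] = 1
            x = x[:, -1, :].view(shape)       # keep only the last token's hidden state
        return x

    prefill = torch.randn(1, 1024, 4096)      # [batch, seq_len, hidden]
    decode = torch.randn(1, 1, 4096)          # single-token decode step
    print(reshape_lm_head_input(prefill).shape)  # torch.Size([1, 1, 4096])
    print(reshape_lm_head_input(decode).shape)   # torch.Size([1, 1, 4096]), unchanged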
@@ -505,7 +515,8 @@ class MatMulLowBitCPU(torch.autograd.Function):

 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
-                 conver_to_half=True, mp_group=None, enable_xetla=False):
+                 conver_to_half=True, mp_group=None, enable_xetla=False,
+                 optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
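
Because the new keyword defaults to False, existing call sites keep their behaviour. A hypothetical construction sketch; the import path and the "sym_int4" qtype key are assumptions, as only ggml_tensor_qtype["bf16"] appears elsewhere in this diff:

    # Hypothetical construction sketch; the import path and qtype key are assumed.
    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, ggml_tensor_qtype

    lm_head = LowBitLinear(4096, 32000,                        # hidden size, vocab size
                           qtype=ggml_tensor_qtype["sym_int4"],
                           bias=False,
                           optimize_lm_head=True)              # new flag from this PR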
@@ -520,6 +531,7 @@ class LowBitLinear(nn.Linear):
         self.mp_group = mp_group
         self.compute_dtype = None  # only for training
         self.enable_xetla = enable_xetla
+        self.optimize_lm_head = optimize_lm_head

     def forward(self, x: torch.Tensor):
         # Due to inconsistent training status in some models like Baichuan-7b-Chat,
@@ -536,6 +548,9 @@ class LowBitLinear(nn.Linear):
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)

+        if self.optimize_lm_head:
+            x = reshape_lm_head_input(x)
+
         # [batch, input_num, in_len]
         # input_num == token num for Transformer
         x_shape = x.shape
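
Dropping all but the last prefill position is safe for the first generated token because a linear layer acts on each position independently, so the last row of the full logits equals the logits computed from the reduced input. A toy check with a plain nn.Linear standing in for the low-bit lm_head:

    # Toy check that reducing the lm_head input to the last token leaves the
    # first-token logits unchanged. A plain nn.Linear stands in for the low-bit
    # lm_head; shapes are illustrative.
    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    hidden, vocab, seq = 64, 100, 32
    lm_head = nn.Linear(hidden, vocab, bias=False)

    x = torch.randn(1, seq, hidden)            # prefill hidden states
    full_logits = lm_head(x)                   # [1, seq, vocab]
    reduced_logits = lm_head(x[:, -1:, :])     # [1, 1, vocab], what the optimized path computes

    assert torch.allclose(full_logits[:, -1:, :], reduced_logits)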
@@ -632,7 +647,8 @@ class LowBitLinear(nn.Linear):

 class FP16Linear(nn.Linear):
     def __init__(self, input_features, output_features, bias=True,
-                 mp_group=None, weight_type=1):
+                 mp_group=None, weight_type=1,
+                 optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
         self.out_len = output_features
@@ -644,11 +660,15 @@
         # weigh_type = 2 means weight has been transposed
         # weigh_type = 3 means weight has been transposed by esimd method
         self.weight_type = 1
+        self.optimize_lm_head = optimize_lm_head

     def forward(self, x: torch.Tensor):
         # only work for GPU
         invalidInputError(x.device.type == "xpu",
                           "FP16Linear only works for Intel GPUs")
+        if self.optimize_lm_head:
+            x = reshape_lm_head_input(x)
+
         x = x.to(torch.float16)
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
@@ -743,7 +763,8 @@ class FP16Linear(nn.Linear):

 class BF16Linear(nn.Linear):
     def __init__(self, input_features, output_features, bias=True,
-                 mp_group=None, compute_dtype=None):
+                 mp_group=None, compute_dtype=None,
+                 optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
         self.out_len = output_features
@@ -752,8 +773,12 @@ class BF16Linear(nn.Linear):
         self.qtype = ggml_tensor_qtype["bf16"]
         self.mp_group = mp_group
         self.compute_dtype = compute_dtype
+        self.optimize_lm_head = optimize_lm_head

     def forward(self, x: torch.Tensor):
+        if self.optimize_lm_head:
+            x = reshape_lm_head_input(x)
+
         x = x.to(torch.bfloat16)
         if self.weight is not None and self.weight.dtype != x.dtype:
             self.weight.data = self.weight.data.to(x.dtype)