diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 07b47688..1c9a66c6 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -211,6 +211,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
                     not isinstance(module, LowBitLinear)):
                 in_features, out_features, mp_group = linear_args
+                optimize_lm_head = False
+                if name == "lm_head":
+                    if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
+                                                                          None) == "1":
+                        optimize_lm_head = True
                 with init_empty_weights():
                     new_linear = None
                     is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
@@ -225,6 +230,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             bias=has_bias,
                             mp_group=mp_group,
                             enable_xetla=enable_xetla,
+                            optimize_lm_head=optimize_lm_head
                         )
                         device = module.qweight.data.device
                         invalidInputError(device.type != "meta",
@@ -253,6 +259,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             module.bias is not None,
                             mp_group=mp_group,
                             enable_xetla=enable_xetla,
+                            optimize_lm_head=optimize_lm_head
                         )
                         cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
                                                                            full_module_name,
@@ -280,6 +287,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             out_features,
                             module.bias is not None,
                             mp_group=mp_group,
+                            optimize_lm_head=optimize_lm_head
                         )
                         device = module.weight.data.device
                         from bigdl.llm.transformers.utils import get_ipex_version
@@ -301,6 +309,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             out_features,
                             module.bias is not None,
                             mp_group=mp_group,
+                            optimize_lm_head=optimize_lm_head
                         )
                         device = module.weight.data.device
                         # convert here
diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index bcb3b995..c73eec79 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -246,6 +246,16 @@ def ggml_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int, qtype:
     return dst_tensor
 
 
+def reshape_lm_head_input(x):
+    if x.dim() > 3:
+        x = x.reshape([-1, x.shape[-2], x.shape[-1]])
+    shape = list(x.size())
+    if shape[1] > 10:
+        shape[1] = 1
+        x = x[:, -1, :].view(shape)
+    return x
+
+
 # Rename to FP4Params to trigger initializing
 # the params layer with all parameters on the CPU
 # https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
@@ -505,7 +515,8 @@ class MatMulLowBitCPU(torch.autograd.Function):
 
 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
-                 conver_to_half=True, mp_group=None, enable_xetla=False):
+                 conver_to_half=True, mp_group=None, enable_xetla=False,
+                 optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
@@ -520,6 +531,7 @@ class LowBitLinear(nn.Linear):
         self.mp_group = mp_group
         self.compute_dtype = None  # only for training
         self.enable_xetla = enable_xetla
+        self.optimize_lm_head = optimize_lm_head
 
     def forward(self, x: torch.Tensor):
         # Due to inconsistent training status in some models like Baichuan-7b-Chat,
@@ -536,6 +548,9 @@ class LowBitLinear(nn.Linear):
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
 
+        if self.optimize_lm_head:
+            x = reshape_lm_head_input(x)
+
         # [batch, input_num, in_len]
         # input_num == token num for Transformer
         x_shape = x.shape
@@ -632,7 +647,8 @@ class LowBitLinear(nn.Linear):
 
 class FP16Linear(nn.Linear):
     def __init__(self, input_features, output_features, bias=True,
-                 mp_group=None, weight_type=1):
+                 mp_group=None, weight_type=1,
+                 optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
         self.out_len = output_features
@@ -644,11 +660,15 @@ class FP16Linear(nn.Linear):
         # weigh_type = 2 means weight has been transposed
         # weigh_type = 3 means weight has been transposed by esimd method
         self.weight_type = 1
+        self.optimize_lm_head = optimize_lm_head
 
     def forward(self, x: torch.Tensor):
         # only work for GPU
         invalidInputError(x.device.type == "xpu",
                           "FP16Linear only works for Intel GPUs")
+        if self.optimize_lm_head:
+            x = reshape_lm_head_input(x)
+
         x = x.to(torch.float16)
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
@@ -743,7 +763,8 @@ class FP16Linear(nn.Linear):
 
 class BF16Linear(nn.Linear):
     def __init__(self, input_features, output_features, bias=True,
-                 mp_group=None, compute_dtype=None):
+                 mp_group=None, compute_dtype=None,
+                 optimize_lm_head=False):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
         self.out_len = output_features
@@ -752,8 +773,12 @@ class BF16Linear(nn.Linear):
         self.qtype = ggml_tensor_qtype["bf16"]
         self.mp_group = mp_group
         self.compute_dtype = compute_dtype
+        self.optimize_lm_head = optimize_lm_head
 
     def forward(self, x: torch.Tensor):
+        if self.optimize_lm_head:
+            x = reshape_lm_head_input(x)
+
         x = x.to(torch.bfloat16)
         if self.weight is not None and self.weight.dtype != x.dtype:
             self.weight.data = self.weight.data.to(x.dtype)
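For reference, a minimal standalone sketch of what the new `reshape_lm_head_input` helper does to the hidden states entering `lm_head` (the tensor shapes below are illustrative, not taken from the patch): when the sequence dimension is longer than 10, only the last position is kept, so the vocabulary projection runs over a single token rather than the whole prompt. Per the `convert.py` change above, this path is opt-in via `BIGDL_OPTIMIZE_LM_HEAD=1` and currently limited to the `gptj` and `llama` model types.

```python
import torch

def reshape_lm_head_input(x):
    # Copied from the patch: collapse extra leading dims, then keep only
    # the last token's hidden state when the sequence is longer than 10.
    if x.dim() > 3:
        x = x.reshape([-1, x.shape[-2], x.shape[-1]])
    shape = list(x.size())
    if shape[1] > 10:
        shape[1] = 1
        x = x[:, -1, :].view(shape)
    return x

# Hypothetical prefill input: [batch=1, seq_len=128, hidden=4096].
hidden_states = torch.randn(1, 128, 4096)
print(reshape_lm_head_input(hidden_states).shape)  # torch.Size([1, 1, 4096])

# Short decode-step inputs (seq_len <= 10) pass through unchanged.
decode_step = torch.randn(1, 1, 4096)
print(reshape_lm_head_input(decode_step).shape)    # torch.Size([1, 1, 4096])
```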