From 6dad8d16dff361873939333a7e02a78643b44c3f Mon Sep 17 00:00:00 2001
From: Shengsheng Huang
Date: Wed, 18 Oct 2023 14:05:07 +0800
Subject: [PATCH] optimize NormHead for Baichuan2 (#9205)

* optimize NormHead for Baichuan2

* fix ut and change name

* rename functions
---
 .../llm/src/bigdl/llm/transformers/convert.py | 30 +++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 08e379c0..30d9d0e3 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -121,10 +121,36 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
     return model, has_been_replaced
 
 
+def _optimize_pre(model):
+    from transformers.modeling_utils import PreTrainedModel
+    # All huggingface format models are inherited from `PreTrainedModel`
+    if not isinstance(model, PreTrainedModel):
+        logger.info("Only HuggingFace Transformers models are currently "
+                    "supported for further optimizations")
+        return model
+    # process NormHead module in Baichuan2 7B and 13B
+    if model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
+        # NormHead do normalization on the weights just once at inference time.
+        # so we do it in advance and convert it to Linear so that it can be replaced.
+        # modeling_module_name = model.__class__.__module__
+        # module = importlib.import_module(modeling_module_name)
+        if hasattr(model, 'lm_head') and model.lm_head is not None:
+            # do we need to check the class instance?
+            vocab_size, hidden_size = model.lm_head.weight.shape
+            norm_weight = nn.functional.normalize(model.lm_head.weight.data)
+            model.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
+            model.lm_head.weight.data = norm_weight
+    return model
+
+
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
                          modules_to_not_convert=None):
     modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
+
+    if optimize_model:
+        model = _optimize_pre(model)
+
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
         None, convert_shape_only,
@@ -143,7 +169,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         pass
 
     if optimize_model:
-        model = optimize(model)
+        model = _optimize_post(model)
     return model
 
 
@@ -155,7 +181,7 @@ def convert_forward(m, target_m, new_forward):
             convert_forward(sub_m, target_m, new_forward)
 
 
-def optimize(model):
+def _optimize_post(model):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
     from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
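
Note (not part of the patch): the sketch below illustrates the idea behind _optimize_pre, namely that normalizing the lm_head weight once ahead of time and swapping in a plain nn.Linear gives the same inference-time output as Baichuan2's NormHead, while producing a module that the low-bit Linear replacement pass can convert. The NormHead class here is a simplified stand-in for the one in Baichuan2's modeling code, and the sizes are made up for the example.

# Minimal sketch under the assumptions above.
import torch
import torch.nn as nn

class NormHead(nn.Module):
    # Simplified stand-in: L2-normalizes each weight row on every forward pass.
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(vocab_size, hidden_size))

    def forward(self, hidden_states):
        norm_weight = nn.functional.normalize(self.weight)
        return nn.functional.linear(hidden_states, norm_weight)

hidden_size, vocab_size = 32, 64  # toy sizes for illustration only
head = NormHead(hidden_size, vocab_size)

# What the patch does in advance: normalize once, then use a plain nn.Linear,
# which _replace_with_low_bit_linear knows how to convert.
linear_head = nn.Linear(hidden_size, vocab_size, bias=False)
linear_head.weight.data = nn.functional.normalize(head.weight.data)

x = torch.randn(2, hidden_size)
assert torch.allclose(head(x), linear_head(x), atol=1e-6)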