optimize NormHead for Baichuan2 (#9205)

* optimize NormHead for Baichuan2

* fix ut and change name

* rename functions
Shengsheng Huang 2023-10-18 14:05:07 +08:00 committed by GitHub
parent a3b664ed03
commit 6dad8d16df

@@ -121,10 +121,36 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
    return model, has_been_replaced


def _optimize_pre(model):
    from transformers.modeling_utils import PreTrainedModel
    # All HuggingFace-format models inherit from `PreTrainedModel`.
    if not isinstance(model, PreTrainedModel):
        logger.info("Only HuggingFace Transformers models are currently "
                    "supported for further optimizations")
        return model
    # Process the NormHead module in Baichuan2 7B and 13B.
    if model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
        # NormHead only needs to normalize its weights once at inference time,
        # so do the normalization in advance and convert the module to a plain
        # nn.Linear so that it can be replaced by the low-bit conversion below.
        if hasattr(model, 'lm_head') and model.lm_head is not None:
            vocab_size, hidden_size = model.lm_head.weight.shape
            norm_weight = nn.functional.normalize(model.lm_head.weight.data)
            model.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
            model.lm_head.weight.data = norm_weight
    return model
def ggml_convert_low_bit(model, qtype, optimize_model=True,
                         convert_shape_only=False, device="cpu",
                         modules_to_not_convert=None):
    modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
    if optimize_model:
        model = _optimize_pre(model)

    model, has_been_replaced = _replace_with_low_bit_linear(
        model, qtype, modules_to_not_convert,
        None, convert_shape_only,
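
The _optimize_pre hook in the hunk above folds NormHead's weight normalization into the weights themselves and swaps in a plain nn.Linear. Below is a minimal, hypothetical sketch of that idea (NormHeadSketch is illustrative, not the actual Baichuan2 implementation) showing why the folded Linear produces the same logits at inference time:

# Minimal sketch (not the actual Baichuan2 code): a NormHead-style head that
# L2-normalizes its weight rows on every call, compared against a plain
# nn.Linear holding the pre-normalized weight, as _optimize_pre does above.
import torch
import torch.nn as nn


class NormHeadSketch(nn.Module):
    """Hypothetical stand-in for Baichuan2's NormHead at inference time."""
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(vocab_size, hidden_size))

    def forward(self, hidden_states):
        # Normalize each vocab row of the weight, then project.
        norm_weight = nn.functional.normalize(self.weight)
        return nn.functional.linear(hidden_states, norm_weight)


if __name__ == "__main__":
    torch.manual_seed(0)
    head = NormHeadSketch(hidden_size=8, vocab_size=16)

    # Fold the normalization into a plain Linear, mirroring _optimize_pre.
    folded = nn.Linear(8, 16, bias=False)
    folded.weight.data = nn.functional.normalize(head.weight.data)

    x = torch.randn(2, 8)
    assert torch.allclose(head(x), folded(x), atol=1e-6)
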
@@ -143,7 +169,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
        pass

    if optimize_model:
-        model = optimize(model)
+        model = _optimize_post(model)
    return model
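
For context, a rough usage sketch of how this conversion path is typically reached from BigDL-LLM's user-facing wrapper; the checkpoint id is a placeholder and the keyword defaults are assumptions:

# Illustrative only: loading Baichuan2 through the bigdl.llm wrapper is assumed
# to call ggml_convert_low_bit() internally, which runs _optimize_pre() before
# the low-bit replacement and _optimize_post() after it when optimize_model=True.
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-7B-Chat",   # placeholder checkpoint id
    load_in_4bit=True,                  # triggers the low-bit conversion
    trust_remote_code=True,
)
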
@@ -155,7 +181,7 @@ def convert_forward(m, target_m, new_forward):
        convert_forward(sub_m, target_m, new_forward)


-def optimize(model):
+def _optimize_post(model):
    from packaging import version
    from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
    from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
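
The diff is truncated here. As a hedged sketch only, this is how convert_forward is typically used together with the imports above; the module path for convert_forward, the version gate, and the target classes are assumptions, not taken from this diff:

# Hypothetical continuation, not the actual _optimize_post body: walk the model
# with convert_forward() and rebind the forward of matching submodules to the
# optimized implementations imported above.
import transformers
from packaging import version

from bigdl.llm.transformers.convert import convert_forward  # assumed module path
from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
from bigdl.llm.transformers.models.llama import llama_rms_norm_forward


def _optimize_post_sketch(model):
    trans_version = transformers.__version__
    # Assumed gate: the "_4_31" suffix suggests the forward targets transformers >= 4.31.
    if version.parse(trans_version) >= version.parse("4.31.0"):
        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
        convert_forward(model, LlamaAttention, llama_attention_forward_4_31)
        convert_forward(model, LlamaRMSNorm, llama_rms_norm_forward)
    return model
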