optimize NormHead for Baichuan2 (#9205)
* optimize NormHead for Baichuan2
* fix ut and change name
* rename functions
This commit is contained in:
parent a3b664ed03
commit 6dad8d16df
1 changed file with 28 additions and 2 deletions
@@ -121,10 +121,36 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
     return model, has_been_replaced
 
 
+def _optimize_pre(model):
+    from transformers.modeling_utils import PreTrainedModel
+    # All huggingface format models are inherited from `PreTrainedModel`
+    if not isinstance(model, PreTrainedModel):
+        logger.info("Only HuggingFace Transformers models are currently "
+                    "supported for further optimizations")
+        return model
+    # process NormHead module in Baichuan2 7B and 13B
+    if model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
+        # NormHead does normalization on the weights just once at inference time,
+        # so we do it in advance and convert it to Linear so that it can be replaced
+        # modeling_module_name = model.__class__.__module__
+        # module = importlib.import_module(modeling_module_name)
+        if hasattr(model, 'lm_head') and model.lm_head is not None:
+            # do we need to check the class instance?
+            vocab_size, hidden_size = model.lm_head.weight.shape
+            norm_weight = nn.functional.normalize(model.lm_head.weight.data)
+            model.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
+            model.lm_head.weight.data = norm_weight
+    return model
+
+
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
                          modules_to_not_convert=None):
     modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
+
+    if optimize_model:
+        model = _optimize_pre(model)
+
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
         None, convert_shape_only,
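For context on the comment in the new _optimize_pre: Baichuan2's NormHead re-normalizes its weight on every forward call, but at inference time the weight never changes, so normalizing once up front and storing the result in a plain nn.Linear gives the same logits. Below is a minimal, self-contained sketch of that equivalence; the NormHead-style forward is paraphrased from the comment in the diff, not copied from Baichuan2's modeling code.

import torch
import torch.nn as nn

hidden_size, vocab_size = 8, 16
x = torch.randn(2, hidden_size)
weight = torch.randn(vocab_size, hidden_size)

# NormHead-style head: L2-normalize the weight rows on every forward call
out_normhead = nn.functional.linear(x, nn.functional.normalize(weight))

# this commit's approach: normalize the rows once, store them in a plain Linear
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
lm_head.weight.data = nn.functional.normalize(weight)
out_linear = lm_head(x)

assert torch.allclose(out_normhead, out_linear)  # identical logits

Converting the head to nn.Linear also matters for the surrounding code: _replace_with_low_bit_linear only swaps nn.Linear modules for low-bit ones, so a custom NormHead class would otherwise be skipped.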
@@ -143,7 +169,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         pass
 
     if optimize_model:
-        model = optimize(model)
+        model = _optimize_post(model)
     return model
 
 
@@ -155,7 +181,7 @@ def convert_forward(m, target_m, new_forward):
         convert_forward(sub_m, target_m, new_forward)
 
 
-def optimize(model):
+def _optimize_post(model):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
     from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
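The hunk header references convert_forward(m, target_m, new_forward), whose body is not part of this diff; only its recursive call is visible. A plausible reconstruction of what such a helper does, sketched from that call site (the isinstance check and the method binding are assumptions, not code from this commit):

import types

def convert_forward(m, target_m, new_forward):
    # recurse through the module tree and rebind `forward` on every
    # submodule that is an instance of target_m
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            sub_m.forward = types.MethodType(new_forward, sub_m)
        convert_forward(sub_m, target_m, new_forward)

Given the imports above, _optimize_post would then apply patches along the lines of convert_forward(model, LlamaAttention, llama_attention_forward_4_31); the actual call sites are outside this hunk.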