optimize NormHead for Baichuan2 (#9205)
* optimize NormHead for Baichuan2
* fix ut and change name
* rename functions
parent a3b664ed03
commit 6dad8d16df
1 changed file with 28 additions and 2 deletions
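Background on the change: Baichuan2's NormHead L2-normalizes the lm_head weight before the output projection, and at inference time that normalization only needs to happen once. The new _optimize_pre hook therefore normalizes the weight up front and swaps the head for a plain nn.Linear, which the low-bit conversion pass can then replace. Below is a minimal, self-contained sketch of why that swap is safe; ToyNormHead is a simplified stand-in for Baichuan2's NormHead (the real one normalizes lazily at inference), and the sizes are illustrative.

import torch
import torch.nn as nn


class ToyNormHead(nn.Module):
    """Simplified stand-in for Baichuan2's NormHead: a head whose weight is
    L2-normalized before the output projection."""

    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(vocab_size, hidden_size))

    def forward(self, hidden_states):
        # Normalize each vocab row of the weight, then project.
        norm_weight = nn.functional.normalize(self.weight)
        return nn.functional.linear(hidden_states, norm_weight)


torch.manual_seed(0)
head = ToyNormHead(hidden_size=32, vocab_size=128)
x = torch.randn(4, 32)

# What _optimize_pre does: normalize the weight once, store it in a plain
# nn.Linear, and let the low-bit conversion pass replace that Linear later.
vocab_size, hidden_size = head.weight.shape
linear_head = nn.Linear(hidden_size, vocab_size, bias=False)
linear_head.weight.data = nn.functional.normalize(head.weight.data)

# Identical logits, so the swap does not change inference results.
assert torch.allclose(head(x), linear_head(x))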
@@ -121,10 +121,36 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
     return model, has_been_replaced
 
 
+def _optimize_pre(model):
+    from transformers.modeling_utils import PreTrainedModel
+    # All huggingface format models inherit from `PreTrainedModel`
+    if not isinstance(model, PreTrainedModel):
+        logger.info("Only HuggingFace Transformers models are currently "
+                    "supported for further optimizations")
+        return model
+    # process the NormHead module in Baichuan2 7B and 13B
+    if model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
+        # NormHead does normalization on the weights just once at inference time,
+        # so we do it in advance and convert it to a Linear so that it can be replaced.
+        # modeling_module_name = model.__class__.__module__
+        # module = importlib.import_module(modeling_module_name)
+        if hasattr(model, 'lm_head') and model.lm_head is not None:
+            # do we need to check the class instance?
+            vocab_size, hidden_size = model.lm_head.weight.shape
+            norm_weight = nn.functional.normalize(model.lm_head.weight.data)
+            model.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
+            model.lm_head.weight.data = norm_weight
+    return model
+
+
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
                          modules_to_not_convert=None):
     modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
+
+    if optimize_model:
+        model = _optimize_pre(model)
+
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
         None, convert_shape_only,
@@ -143,7 +169,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         pass
 
     if optimize_model:
-        model = optimize(model)
+        model = _optimize_post(model)
     return model
 
 
@@ -155,7 +181,7 @@ def convert_forward(m, target_m, new_forward):
         convert_forward(sub_m, target_m, new_forward)
 
 
-def optimize(model):
+def _optimize_post(model):
     from packaging import version
     from bigdl.llm.transformers.models.llama import llama_attention_forward_4_31
     from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
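For context on the rename in the last two hunks: _optimize_post (previously optimize) applies model-specific forward replacements through convert_forward, whose recursive call appears in the diff above. A rough sketch of that pattern follows; only the recursion over children is visible in this commit, so the isinstance check and the method rebinding are illustrative assumptions rather than the actual BigDL implementation.

import types


def convert_forward_sketch(m, target_m, new_forward):
    # Assumed shape of convert_forward: walk the module tree and rebind
    # `forward` on every submodule whose class matches `target_m`.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            sub_m.forward = types.MethodType(new_forward, sub_m)
        convert_forward_sketch(sub_m, target_m, new_forward)


# Hypothetical usage, mirroring the llama imports in _optimize_post:
# convert_forward_sketch(model, LlamaRMSNorm, llama_rms_norm_forward)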