LLM: support optimize_model=True for Mixtral model (#9657)

binbin Deng 2023-12-12 16:41:26 +08:00 committed by GitHub
parent 017932a7fb
commit 59ce86d292

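For context, a minimal usage sketch of the path this commit enables, assuming the standard bigdl-llm loading API; the model path, prompt, and generation settings here are illustrative:

    # Minimal sketch: load Mixtral with BigDL-LLM optimizations enabled.
    # Assumes bigdl-llm is installed and transformers >= 4.36.0, as the diff below requires.
    from bigdl.llm.transformers import AutoModelForCausalLM
    from transformers import AutoTokenizer

    model_path = "mistralai/Mixtral-8x7B-v0.1"  # the model named in the diff's comment

    # optimize_model=True routes through _optimize_post, which applies the
    # convert_forward replacements added in this commit.
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 optimize_model=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    inputs = tokenizer("What is AI?", return_tensors="pt")
    output = model.generate(inputs.input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))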

@@ -605,17 +605,39 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.AquilaRMSNorm,
                             llama_rms_norm_forward)
-        elif model.config.model_type == "mistral":
+        elif model.config.model_type == "mixtral":
+            # For mistralai/Mixtral-8x7B-v0.1
+            invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"),
+                              "Please upgrade transformers to 4.36.0 or higher version "
+                              "to run Mixtral models.")
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
-            from bigdl.llm.transformers.models.mistral import mistral_attention_forward
-            convert_forward(model,
-                            module.MistralAttention,
-                            mistral_attention_forward
-                            )
             convert_forward(model,
-                            module.MistralRMSNorm,
+                            module.MixtralRMSNorm,
                             llama_rms_norm_forward)
+        elif model.config.model_type == "mistral":
+            if model.config.architectures is not None and \
+               model.config.architectures[0] == "MixtralForCausalLM":
+                # For DiscoResearch/mixtral-7b-8expert
+                invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"),
+                                  "Please upgrade transformers to 4.36.0 or higher version "
+                                  "to run Mixtral models.")
+                modeling_module_name = model.__class__.__module__
+                module = importlib.import_module(modeling_module_name)
+                convert_forward(model,
+                                module.MistralRMSNorm,
+                                llama_rms_norm_forward)
+            else:
+                modeling_module_name = model.__class__.__module__
+                module = importlib.import_module(modeling_module_name)
+                from bigdl.llm.transformers.models.mistral import mistral_attention_forward
+                convert_forward(model,
+                                module.MistralAttention,
+                                mistral_attention_forward
+                                )
+                convert_forward(model,
+                                module.MistralRMSNorm,
+                                llama_rms_norm_forward)
         elif model.config.model_type == "Yi":
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
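The second branch in the hunk above handles DiscoResearch/mixtral-7b-8expert, which predates the official transformers Mixtral port: its config reports model_type "mistral" while its architecture is MixtralForCausalLM, so model_type alone cannot identify it. A quick way to inspect the two config fields the dispatch keys on (a sketch, assuming Hugging Face Hub access; the printed values are inferred from the branch conditions in the diff):

    # Sketch: print the config fields that _optimize_post dispatches on.
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
    print(cfg.model_type, cfg.architectures)   # expected: mixtral ['MixtralForCausalLM']

    cfg = AutoConfig.from_pretrained("DiscoResearch/mixtral-7b-8expert",
                                     trust_remote_code=True)
    print(cfg.model_type, cfg.architectures)   # expected: mistral ['MixtralForCausalLM']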
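convert_forward, called throughout this hunk, is the helper that swaps the forward method of every matching submodule for an optimized implementation; its definition lives elsewhere in convert.py. A minimal sketch of the pattern, where the signature matches the call sites above but the body is an assumption, not the library's actual code:

    import types
    import torch.nn as nn

    def convert_forward(m: nn.Module, target_m: type, new_forward) -> None:
        # Walk the module tree and rebind `forward` on every instance of
        # target_m, so e.g. every MixtralRMSNorm runs llama_rms_norm_forward.
        for _, sub_m in m.named_children():
            if isinstance(sub_m, target_m):
                sub_m.forward = types.MethodType(new_forward, sub_m)
            convert_forward(sub_m, target_m, new_forward)

The repeated module = importlib.import_module(model.__class__.__module__) lines serve a related goal: they fetch whichever modeling module transformers actually loaded for this checkpoint (including trust_remote_code modules), so classes like MixtralRMSNorm can be referenced without a hard import that would break on other transformers versions.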