LLM: support optimize_model=True for Mixtral model (#9657)
parent 017932a7fb
commit 59ce86d292
1 changed file with 29 additions and 7 deletions
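For context, this is roughly how the new path is exercised from the user side. A minimal sketch, assuming bigdl-llm is installed and transformers >= 4.36.0; the checkpoint name is taken from the comment in the diff, and the exact keyword arguments of from_pretrained may vary by bigdl-llm version:

from bigdl.llm.transformers import AutoModelForCausalLM

# optimize_model=True routes the loaded model through _optimize_post(),
# which after this commit recognizes model_type == "mixtral" and patches
# its RMSNorm modules with llama_rms_norm_forward.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",  # checkpoint named in the diff comment
    load_in_4bit=True,              # assumption: typical bigdl-llm usage
    optimize_model=True,            # the flag this commit wires up for Mixtral
)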
@@ -605,17 +605,39 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.AquilaRMSNorm,
                         llama_rms_norm_forward)
-    elif model.config.model_type == "mistral":
+    elif model.config.model_type == "mixtral":
+        # For mistralai/Mixtral-8x7B-v0.1
+        invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"),
+                          "Please upgrade transformers to 4.36.0 or higher version "
+                          "to run Mixtral models.")
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.mistral import mistral_attention_forward
-        convert_forward(model,
-                        module.MistralAttention,
-                        mistral_attention_forward
-                        )
         convert_forward(model,
-                        module.MistralRMSNorm,
+                        module.MixtralRMSNorm,
                         llama_rms_norm_forward)
+    elif model.config.model_type == "mistral":
+        if model.config.architectures is not None and \
+                model.config.architectures[0] == "MixtralForCausalLM":
+            # For DiscoResearch/mixtral-7b-8expert
+            invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"),
+                              "Please upgrade transformers to 4.36.0 or higher version "
+                              "to run Mixtral models.")
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            convert_forward(model,
+                            module.MistralRMSNorm,
+                            llama_rms_norm_forward)
+        else:
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.mistral import mistral_attention_forward
+            convert_forward(model,
+                            module.MistralAttention,
+                            mistral_attention_forward
+                            )
+            convert_forward(model,
+                            module.MistralRMSNorm,
+                            llama_rms_norm_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
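For readers skimming the hunk: convert_forward replaces the forward method of every submodule whose class matches a target. The helper below is a minimal sketch of that technique under that assumption, not BigDL's exact implementation, and convert_forward_sketch is a hypothetical name:

import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_class, new_forward):
    # Walk every submodule; where its class matches the target
    # (e.g. module.MixtralRMSNorm in the diff), bind the replacement
    # forward function to that instance.
    for _, sub in model.named_modules():
        if sub.__class__ == target_class:
            sub.forward = types.MethodType(new_forward, sub)
    return model

Replacing bound methods in place like this keeps the module tree and its weights untouched, which is why the optimization can be applied after from_pretrained without re-loading the checkpoint.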
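Why the same architecture appears under two branches: per the diff's comments, mistralai/Mixtral-8x7B-v0.1 reports model_type "mixtral", while DiscoResearch/mixtral-7b-8expert reports model_type "mistral" with architectures[0] == "MixtralForCausalLM", so the mistral branch needs the extra architectures check. A way to confirm this yourself (an assumption-laden snippet: it needs network access, and trust_remote_code may be required for this checkpoint):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("DiscoResearch/mixtral-7b-8expert",
                                 trust_remote_code=True)
print(cfg.model_type)     # expected per the diff's branch: "mistral"
print(cfg.architectures)  # expected per the diff's check: ["MixtralForCausalLM"]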