LLM: support optimize_model=True for Mixtral model (#9657)
This commit is contained in:
parent
017932a7fb
commit
59ce86d292
1 changed files with 29 additions and 7 deletions
|
|
@ -605,7 +605,29 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.AquilaRMSNorm,
|
module.AquilaRMSNorm,
|
||||||
llama_rms_norm_forward)
|
llama_rms_norm_forward)
|
||||||
|
elif model.config.model_type == "mixtral":
|
||||||
|
# For mistralai/Mixtral-8x7B-v0.1
|
||||||
|
invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"),
|
||||||
|
"Please upgrade transformers to 4.36.0 or higher version "
|
||||||
|
"to run Mixtral models.")
|
||||||
|
modeling_module_name = model.__class__.__module__
|
||||||
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
convert_forward(model,
|
||||||
|
module.MixtralRMSNorm,
|
||||||
|
llama_rms_norm_forward)
|
||||||
elif model.config.model_type == "mistral":
|
elif model.config.model_type == "mistral":
|
||||||
|
if model.config.architectures is not None and \
|
||||||
|
model.config.architectures[0] == "MixtralForCausalLM":
|
||||||
|
# For DiscoResearch/mixtral-7b-8expert
|
||||||
|
invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"),
|
||||||
|
"Please upgrade transformers to 4.36.0 or higher version "
|
||||||
|
"to run Mixtral models.")
|
||||||
|
modeling_module_name = model.__class__.__module__
|
||||||
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
convert_forward(model,
|
||||||
|
module.MistralRMSNorm,
|
||||||
|
llama_rms_norm_forward)
|
||||||
|
else:
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
from bigdl.llm.transformers.models.mistral import mistral_attention_forward
|
from bigdl.llm.transformers.models.mistral import mistral_attention_forward
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue