diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 02730699..3ba09b1f 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -605,17 +605,39 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward(model, module.AquilaRMSNorm, llama_rms_norm_forward) - elif model.config.model_type == "mistral": + elif model.config.model_type == "mixtral": + # For mistralai/Mixtral-8x7B-v0.1 + invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"), + "Please upgrade transformers to 4.36.0 or higher version " + "to run Mixtral models.") modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) - from bigdl.llm.transformers.models.mistral import mistral_attention_forward convert_forward(model, - module.MistralAttention, - mistral_attention_forward - ) - convert_forward(model, - module.MistralRMSNorm, + module.MixtralRMSNorm, llama_rms_norm_forward) + elif model.config.model_type == "mistral": + if model.config.architectures is not None and \ + model.config.architectures[0] == "MixtralForCausalLM": + # For DiscoResearch/mixtral-7b-8expert + invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"), + "Please upgrade transformers to 4.36.0 or higher version " + "to run Mixtral models.") + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + convert_forward(model, + module.MistralRMSNorm, + llama_rms_norm_forward) + else: + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from bigdl.llm.transformers.models.mistral import mistral_attention_forward + convert_forward(model, + module.MistralAttention, + mistral_attention_forward + ) + convert_forward(model, + module.MistralRMSNorm, + llama_rms_norm_forward) elif model.config.model_type == "Yi": modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name)