diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 69e005af..6db25be0 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1522,7 +1522,7 @@ def _optimize_post(model, lightweight_bmm=False): from ipex_llm.transformers.models.starcoder2 import model_forward convert_forward(model, module.Starcoder2Attention, attention_forward) convert_forward(model, module.Starcoder2Model, model_forward) - elif model.config.model_type in ["phi3", "phi3_v"]: + elif model.config.model_type == "phi": # for phi-2 modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) @@ -1530,7 +1530,7 @@ def _optimize_post(model, lightweight_bmm=False): from ipex_llm.transformers.models.phi import model_forward convert_forward(model, module.PhiAttention, attention_forward) convert_forward(model, module.PhiModel, model_forward) - elif model.config.model_type == "phi3": + elif model.config.model_type in ["phi3", "phi3_v"]: # for phi-3 modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name)