fix phi-2 and phi-3 convert (#11116)

This commit is contained in:
Yishuo Wang 2024-05-23 17:37:37 +08:00 committed by GitHub
parent 37b98a531f
commit 797dbc48b8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1522,7 +1522,7 @@ def _optimize_post(model, lightweight_bmm=False):
from ipex_llm.transformers.models.starcoder2 import model_forward from ipex_llm.transformers.models.starcoder2 import model_forward
convert_forward(model, module.Starcoder2Attention, attention_forward) convert_forward(model, module.Starcoder2Attention, attention_forward)
convert_forward(model, module.Starcoder2Model, model_forward) convert_forward(model, module.Starcoder2Model, model_forward)
elif model.config.model_type in ["phi3", "phi3_v"]: elif model.config.model_type == "phi":
# for phi-2 # for phi-2
modeling_module_name = model.__class__.__module__ modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name) module = importlib.import_module(modeling_module_name)
@ -1530,7 +1530,7 @@ def _optimize_post(model, lightweight_bmm=False):
from ipex_llm.transformers.models.phi import model_forward from ipex_llm.transformers.models.phi import model_forward
convert_forward(model, module.PhiAttention, attention_forward) convert_forward(model, module.PhiAttention, attention_forward)
convert_forward(model, module.PhiModel, model_forward) convert_forward(model, module.PhiModel, model_forward)
elif model.config.model_type == "phi3": elif model.config.model_type in ["phi3", "phi3_v"]:
# for phi-3 # for phi-3
modeling_module_name = model.__class__.__module__ modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name) module = importlib.import_module(modeling_module_name)