fix phi-2 and phi-3 convert (#11116)
This commit is contained in:
parent
37b98a531f
commit
797dbc48b8
1 changed files with 2 additions and 2 deletions
|
|
@ -1522,7 +1522,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
from ipex_llm.transformers.models.starcoder2 import model_forward
|
||||
convert_forward(model, module.Starcoder2Attention, attention_forward)
|
||||
convert_forward(model, module.Starcoder2Model, model_forward)
|
||||
elif model.config.model_type in ["phi3", "phi3_v"]:
|
||||
elif model.config.model_type == "phi":
|
||||
# for phi-2
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
|
|
@ -1530,7 +1530,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
from ipex_llm.transformers.models.phi import model_forward
|
||||
convert_forward(model, module.PhiAttention, attention_forward)
|
||||
convert_forward(model, module.PhiModel, model_forward)
|
||||
elif model.config.model_type == "phi3":
|
||||
elif model.config.model_type in ["phi3", "phi3_v"]:
|
||||
# for phi-3
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
|
|
|
|||
Loading…
Reference in a new issue