fix phi-2 and phi-3 convert (#11116)
This commit is contained in:
parent
37b98a531f
commit
797dbc48b8
1 changed file with 2 additions and 2 deletions
|
|
@ -1522,7 +1522,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
from ipex_llm.transformers.models.starcoder2 import model_forward
|
from ipex_llm.transformers.models.starcoder2 import model_forward
|
||||||
convert_forward(model, module.Starcoder2Attention, attention_forward)
|
convert_forward(model, module.Starcoder2Attention, attention_forward)
|
||||||
convert_forward(model, module.Starcoder2Model, model_forward)
|
convert_forward(model, module.Starcoder2Model, model_forward)
|
||||||
elif model.config.model_type in ["phi3", "phi3_v"]:
|
elif model.config.model_type == "phi":
|
||||||
# for phi-2
|
# for phi-2
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
|
@ -1530,7 +1530,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
from ipex_llm.transformers.models.phi import model_forward
|
from ipex_llm.transformers.models.phi import model_forward
|
||||||
convert_forward(model, module.PhiAttention, attention_forward)
|
convert_forward(model, module.PhiAttention, attention_forward)
|
||||||
convert_forward(model, module.PhiModel, model_forward)
|
convert_forward(model, module.PhiModel, model_forward)
|
||||||
elif model.config.model_type == "phi3":
|
elif model.config.model_type in ["phi3", "phi3_v"]:
|
||||||
# for phi-3
|
# for phi-3
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue