fix user issue (#12692)
This commit is contained in:
parent
68857494a5
commit
f8dc408888
1 changed files with 8 additions and 1 deletions
|
|
@ -52,7 +52,14 @@ import os
|
||||||
|
|
||||||
|
|
||||||
def _ipex_optimize_rmsnorm(_model, supported_classes, is_tpp=False, is_woq=False):
|
def _ipex_optimize_rmsnorm(_model, supported_classes, is_tpp=False, is_woq=False):
|
||||||
from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
|
try:
|
||||||
|
# old version use name `_IPEXRMSNorm`
|
||||||
|
from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion \
|
||||||
|
import _IPEXRMSNorm
|
||||||
|
except ImportError:
|
||||||
|
# new version use name `_IPEXRMSNormCPU`
|
||||||
|
from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion \
|
||||||
|
import _IPEXRMSNormCPU as _IPEXRMSNorm
|
||||||
for supported_class in supported_classes:
|
for supported_class in supported_classes:
|
||||||
lowering_class_cpu(
|
lowering_class_cpu(
|
||||||
_model,
|
_model,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue