diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 24864b4b..2a3b781d 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -662,7 +662,6 @@ def replace_func(m, target_m, func_name, new_func):


 def _optimize_ipex(model):
-    import intel_extension_for_pytorch as ipex
     from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
     from bigdl.llm.transformers.convert_ipex import (
@@ -694,7 +693,6 @@ def _optimize_ipex(model):
         # baichuan2
         rms_classes.append(type(model.model.layers[0].input_layernorm))

-    model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
     _ipex_optimize_model(model, rms_classes)
     return _ipex_jit(model)

diff --git a/python/llm/src/bigdl/llm/transformers/convert_ipex.py b/python/llm/src/bigdl/llm/transformers/convert_ipex.py
index 067c3698..7b2ef0c4 100644
--- a/python/llm/src/bigdl/llm/transformers/convert_ipex.py
+++ b/python/llm/src/bigdl/llm/transformers/convert_ipex.py
@@ -41,6 +41,10 @@ from intel_extension_for_pytorch.transformers.optimize import (
     lowering_class_cpu,
     convert_class,
 )
+from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
+    _enable_tpp,
+    _using_tpp,
+)


 def _ipex_optimize_rmsnorm(_model, supported_classes):
@@ -69,7 +73,7 @@ def _ipex_optimize_decoder(model):
             supported_mlp_class,
             _IPEXDecoderLayerCPU,
             model.config,
-            tpp=False,
+            tpp=bool(_using_tpp()),
             woq=False,
         )

@@ -87,13 +91,22 @@ def _ipex_optimize_attention(model):
             supported_mha_class,
             _IPEXAttentionCPU,
             model.config,
-            tpp=False,
+            tpp=bool(_using_tpp()),
             woq=False,
         )


 def _ipex_optimize_model(model, rms_classes):
-
+    """Run the IPEX CPU optimization passes over ``model``.
+
+    Calls ``_enable_tpp()``, applies ``ipex.optimize`` in bfloat16 with
+    ``inplace=True``, then runs the RMSNorm, attention and decoder passes.
+    Nothing is returned; callers keep using the ``model`` they passed in.
+    """
+    # Called first: the attention/decoder passes below query ``_using_tpp()``.
+    _enable_tpp()
+    import intel_extension_for_pytorch as ipex
+    ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
     _ipex_optimize_rmsnorm(model, rms_classes)
     _ipex_optimize_attention(model)
     _ipex_optimize_decoder(model)