LLM: Reduce speculative _ipex_optimize_model memory use (#10281)

* use tpp

* update ipex
Wang, Jian4 authored 2024-03-01 13:48:23 +08:00, committed by GitHub
parent f0ff0eebe1
commit beb9433cec
2 changed files with 9 additions and 5 deletions
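
In short: the bf16 `ipex.optimize` pass moves out of `_optimize_ipex` (first diff) into `_ipex_optimize_model` (second diff), where IPEX's TPP (Tensor Processing Primitives) kernel selection is enabled first, and the `lowering_class_cpu`/`convert_class` calls now follow `_using_tpp()` instead of hard-coding `tpp=False`. A minimal sketch of the resulting enable-then-optimize ordering, with a stand-in nn.Linear in place of a real transformer model:

# Sketch only: stand-in nn.Linear instead of a real transformer model.
# Enabling TPP before ipex.optimize lets the bf16 conversion take the TPP
# kernel path once, instead of the default weight-prepacking path.
import torch
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
    _enable_tpp,
    _using_tpp,
)

model = torch.nn.Linear(4096, 4096)

_enable_tpp()  # opt into TPP kernels before any optimization pass
ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
assert _using_tpp()  # this is what gates convert_class(..., tpp=True) below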

python/llm/src/bigdl/llm/transformers/convert.py

@@ -662,7 +662,6 @@ def replace_func(m, target_m, func_name, new_func):
 
 def _optimize_ipex(model):
-    import intel_extension_for_pytorch as ipex
     from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
     from bigdl.llm.transformers.convert_ipex import (
@@ -694,7 +693,6 @@ def _optimize_ipex(model):
         # baichuan2
         rms_classes.append(type(model.model.layers[0].input_layernorm))
 
-    model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
     _ipex_optimize_model(model, rms_classes)
     return _ipex_jit(model)
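
With the two deletions above, `_optimize_ipex` no longer runs its own `ipex.optimize` pass (and no longer needs the local `ipex` import); it only collects `rms_classes` and delegates. The single optimization pass now happens inside `_ipex_optimize_model` below, after `_enable_tpp()` has been called, so the model is converted once along the TPP path instead of being optimized here with the default kernels first.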

python/llm/src/bigdl/llm/transformers/convert_ipex.py

@@ -41,6 +41,10 @@ from intel_extension_for_pytorch.transformers.optimize import (
     lowering_class_cpu,
     convert_class,
 )
+from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
+    _enable_tpp,
+    _using_tpp,
+)
 
 
 def _ipex_optimize_rmsnorm(_model, supported_classes):
@@ -69,7 +73,7 @@ def _ipex_optimize_decoder(model):
             supported_mlp_class,
             _IPEXDecoderLayerCPU,
             model.config,
-            tpp=False,
+            tpp=True if _using_tpp() else False,
             woq=False,
         )
@@ -87,13 +91,15 @@ def _ipex_optimize_attention(model):
             supported_mha_class,
             _IPEXAttentionCPU,
             model.config,
-            tpp=False,
+            tpp=True if _using_tpp() else False,
             woq=False,
         )
 
 
 def _ipex_optimize_model(model, rms_classes):
+    _enable_tpp()
+    import intel_extension_for_pytorch as ipex
+    ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
     _ipex_optimize_rmsnorm(model, rms_classes)
     _ipex_optimize_attention(model)
     _ipex_optimize_decoder(model)
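
For reference, a hypothetical way to reach this code path end to end, modeled on bigdl-llm's CPU speculative-decoding examples from the same period; the checkpoint name and keyword arguments are illustrative, not part of this commit:

# Hypothetical usage sketch; kwargs follow bigdl-llm's CPU speculative
# decoding examples of this era, not this commit.
import torch
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # illustrative checkpoint
    load_in_low_bit="bf16",           # bf16 route that ends in _optimize_ipex
    torch_dtype=torch.bfloat16,
    optimize_model=True,
    speculative=True,                 # the speculative-decoding case this
    trust_remote_code=True,           # commit reduces memory for
)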