LLM: Reduce speculative _ipex_optimize_model memory use (#10281)
* use tpp * update ipex
This commit is contained in:
parent
f0ff0eebe1
commit
beb9433cec
2 changed files with 9 additions and 5 deletions
|
|
@ -662,7 +662,6 @@ def replace_func(m, target_m, func_name, new_func):
|
||||||
|
|
||||||
|
|
||||||
def _optimize_ipex(model):
|
def _optimize_ipex(model):
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
|
from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
|
||||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from bigdl.llm.transformers.convert_ipex import (
|
from bigdl.llm.transformers.convert_ipex import (
|
||||||
|
|
@ -694,7 +693,6 @@ def _optimize_ipex(model):
|
||||||
# baichuan2
|
# baichuan2
|
||||||
rms_classes.append(type(model.model.layers[0].input_layernorm))
|
rms_classes.append(type(model.model.layers[0].input_layernorm))
|
||||||
|
|
||||||
model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
|
|
||||||
_ipex_optimize_model(model, rms_classes)
|
_ipex_optimize_model(model, rms_classes)
|
||||||
return _ipex_jit(model)
|
return _ipex_jit(model)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,10 @@ from intel_extension_for_pytorch.transformers.optimize import (
|
||||||
lowering_class_cpu,
|
lowering_class_cpu,
|
||||||
convert_class,
|
convert_class,
|
||||||
)
|
)
|
||||||
|
from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
|
||||||
|
_enable_tpp,
|
||||||
|
_using_tpp,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _ipex_optimize_rmsnorm(_model, supported_classes):
|
def _ipex_optimize_rmsnorm(_model, supported_classes):
|
||||||
|
|
@ -69,7 +73,7 @@ def _ipex_optimize_decoder(model):
|
||||||
supported_mlp_class,
|
supported_mlp_class,
|
||||||
_IPEXDecoderLayerCPU,
|
_IPEXDecoderLayerCPU,
|
||||||
model.config,
|
model.config,
|
||||||
tpp=False,
|
tpp=True if _using_tpp() else False,
|
||||||
woq=False,
|
woq=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -87,13 +91,15 @@ def _ipex_optimize_attention(model):
|
||||||
supported_mha_class,
|
supported_mha_class,
|
||||||
_IPEXAttentionCPU,
|
_IPEXAttentionCPU,
|
||||||
model.config,
|
model.config,
|
||||||
tpp=False,
|
tpp=True if _using_tpp() else False,
|
||||||
woq=False,
|
woq=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _ipex_optimize_model(model, rms_classes):
|
def _ipex_optimize_model(model, rms_classes):
|
||||||
|
_enable_tpp()
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
|
||||||
_ipex_optimize_rmsnorm(model, rms_classes)
|
_ipex_optimize_rmsnorm(model, rms_classes)
|
||||||
_ipex_optimize_attention(model)
|
_ipex_optimize_attention(model)
|
||||||
_ipex_optimize_decoder(model)
|
_ipex_optimize_decoder(model)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue