LLM : Update optimize ipex bf16 (#10038)
* use 4.35.2 and remove
* update rmsnorm
* remove
* remove
* update python style
* update
* update python style
* update
* fix style
* update
* remove whitespace
parent fb53b994f8
commit 7e5cd42a5c

2 changed files with 34 additions and 22 deletions
@@ -578,15 +578,16 @@ def _optimize_ipex(model):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
     from bigdl.llm.transformers.convert_ipex import (
         _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
-        _llama_model_forward_4_35
+        _ipex_optimize_rmsnorm, _llama_model_forward_4_35
     )
 
     AttentionMaskConverter._make_causal_mask = _make_causal_mask
-    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel, _llama_model_forward_4_35)  # noqa
+    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
+                    _llama_model_forward_4_35)
     model = model_convert_reference(model)
-
-    _ipex_optimize_attention(model, transformers.models.llama.modeling_llama.LlamaAttention)
-    _ipex_optimize_decoder(model, transformers.models.llama.modeling_llama.LlamaDecoderLayer)
+    _ipex_optimize_rmsnorm(model)
+    _ipex_optimize_attention(model)
+    _ipex_optimize_decoder(model)
 
     model.register_forward_hook(output_hook, with_kwargs=True)
 
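For context, convert_forward replaces the forward method on every instance of the given module class, which is how _llama_model_forward_4_35 gets attached to LlamaModel above. A minimal sketch of that pattern, using a hypothetical patch_forward helper rather than the BigDL implementation:

import types
import torch.nn as nn

def patch_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    # Walk all sub-modules and rebind `forward` on each instance of target_cls.
    for module in model.modules():
        if isinstance(module, target_cls):
            # Bind the replacement function as a method of this instance only.
            module.forward = types.MethodType(new_forward, module)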
@@ -75,7 +75,33 @@ def _set_optimized_model_for_generation(
     return model
 
 
-def _ipex_optimize_decoder(model, decoder_layer):
+def _ipex_optimize_rmsnorm(_model):
+    from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
+    import transformers
+    supported_classes = [
+        transformers.models.llama.modeling_llama.LlamaRMSNorm,
+    ]
+    if _model.config.architectures[0] == "BaichuanForCausalLM":
+        supported_classes.append(type(_model.model.layers[0].input_layernorm))
+    if (
+        _model.config.architectures[0] == "ChatGLMModel"
+        and _model.config.rmsnorm
+    ):
+        supported_classes.append(
+            type(_model.transformer.encoder.layers[0].input_layernorm)
+        )
+    for supported_class in supported_classes:
+        lowering_class_cpu(
+            _model,
+            supported_class,
+            _IPEXRMSNorm,
+            _model.config,
+            tpp=False,
+            woq=False,
+        )
+
+
+def _ipex_optimize_decoder(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
         _IPEXDecoderLayerRef
     )
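The new _ipex_optimize_rmsnorm collects the RMSNorm classes to lower (LLaMA's, plus Baichuan's and ChatGLM's, discovered via type(...) since those checkpoints ship custom modeling code) and passes each one to lowering_class_cpu along with the fused _IPEXRMSNorm and the model config. A rough sketch of that class-swapping pattern, using a hypothetical swap_module_class helper rather than the real IPEX API:

import torch.nn as nn

def swap_module_class(model: nn.Module, source_cls: type, build_replacement) -> None:
    # Recursively replace every child module of source_cls with a fused
    # equivalent constructed from the original module.
    for name, child in model.named_children():
        if isinstance(child, source_cls):
            setattr(model, name, build_replacement(child))
        else:
            swap_module_class(child, source_cls, build_replacement)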
@@ -91,16 +117,9 @@ def _ipex_optimize_decoder(model, decoder_layer):
             tpp=False,
             woq=False,
         )
-    convert_class(
-        model,
-        decoder_layer,
-        _IPEXDecoderLayerRef,
-        model.config,
-        distributed=True,
-    )
 
 
-def _ipex_optimize_attention(model, attention_layer):
+def _ipex_optimize_attention(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
         _IPEXAttentionRef
     )
@@ -116,13 +135,6 @@ def _ipex_optimize_attention(model, attention_layer):
             tpp=False,
             woq=False,
         )
-    convert_class(
-        model,
-        attention_layer,
-        _IPEXAttentionRef,
-        model.config,
-        distributed=True,
-    )
 
 
 def _ipex_jit(model):
@@ -178,7 +190,6 @@ def _make_causal_mask(
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
 
-
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 
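The _make_causal_mask override installed on AttentionMaskConverter in _optimize_ipex ends with the same expand as shown in this hunk. A self-contained sketch of the standard causal-mask construction, useful for checking the expected shape; everything except the final expand is an assumption mirroring the upstream transformers helper, not the patched IPEX code:

import torch

def make_causal_mask_sketch(bsz, tgt_len, past_key_values_length, dtype=torch.float32):
    # Lower-triangular mask: 0 where attention is allowed, dtype-min elsewhere.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype)
    positions = torch.arange(tgt_len)
    mask.masked_fill_(positions < (positions + 1).view(tgt_len, 1), 0)
    if past_key_values_length > 0:
        # Positions already in the KV cache are always visible.
        past = torch.zeros(tgt_len, past_key_values_length, dtype=dtype)
        mask = torch.cat([past, mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

# e.g. make_causal_mask_sketch(2, 4, 3).shape == torch.Size([2, 1, 4, 7])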