LLM: Update optimize ipex bf16 (#10038)

* use 4.35.2 and remove

* update rmsnorm

* remove

* remove

* update python style

* update

* update python style

* update

* fix style

* update

* remove whitespace
Authored by Wang, Jian4 on 2024-01-31 10:59:55 +08:00; committed by GitHub
parent fb53b994f8
commit 7e5cd42a5c
2 changed files with 34 additions and 22 deletions

bigdl/llm/transformers/convert.py

@@ -578,15 +578,16 @@ def _optimize_ipex(model):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
     from bigdl.llm.transformers.convert_ipex import (
         _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
-        _llama_model_forward_4_35
+        _ipex_optimize_rmsnorm, _llama_model_forward_4_35
     )
     AttentionMaskConverter._make_causal_mask = _make_causal_mask
-    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel, _llama_model_forward_4_35)  # noqa
+    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
+                    _llama_model_forward_4_35)
     model = model_convert_reference(model)
-    _ipex_optimize_attention(model, transformers.models.llama.modeling_llama.LlamaAttention)
-    _ipex_optimize_decoder(model, transformers.models.llama.modeling_llama.LlamaDecoderLayer)
+    _ipex_optimize_rmsnorm(model)
+    _ipex_optimize_attention(model)
+    _ipex_optimize_decoder(model)
     model.register_forward_hook(output_hook, with_kwargs=True)
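This hunk wires the new _ipex_optimize_rmsnorm helper into _optimize_ipex and drops the explicit class arguments from the attention and decoder helpers. For context, a convert_forward-style helper generally walks the module tree and rebinds the replacement forward onto every module of the target class; a minimal sketch under standard PyTorch assumptions (not the exact BigDL implementation):

    import torch.nn as nn

    def convert_forward(m: nn.Module, target_cls: type, new_forward) -> None:
        # Rebind new_forward as the bound `forward` method of every
        # submodule whose class matches target_cls, recursively.
        if m.__class__ == target_cls:
            m.forward = new_forward.__get__(m, m.__class__)
        for child in m.children():
            convert_forward(child, target_cls, new_forward)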

bigdl/llm/transformers/convert_ipex.py

@@ -75,7 +75,33 @@ def _set_optimized_model_for_generation(
     return model
 
-def _ipex_optimize_decoder(model, decoder_layer):
+def _ipex_optimize_rmsnorm(_model):
+    from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
+    import transformers
+    supported_classes = [
+        transformers.models.llama.modeling_llama.LlamaRMSNorm,
+    ]
+    if _model.config.architectures[0] == "BaichuanForCausalLM":
+        supported_classes.append(type(_model.model.layers[0].input_layernorm))
+    if (
+        _model.config.architectures[0] == "ChatGLMModel"
+        and _model.config.rmsnorm
+    ):
+        supported_classes.append(
+            type(_model.transformer.encoder.layers[0].input_layernorm)
+        )
+    for supported_class in supported_classes:
+        lowering_class_cpu(
+            _model,
+            supported_class,
+            _IPEXRMSNorm,
+            _model.config,
+            tpp=False,
+            woq=False,
+        )
+
+
+def _ipex_optimize_decoder(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
         _IPEXDecoderLayerRef
     )
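_ipex_optimize_rmsnorm gathers the RMSNorm classes to replace: LLaMA's class is importable directly, while Baichuan's and ChatGLM's are looked up from a live module instance because their modeling code is loaded from the checkpoint at runtime. Each collected class is then lowered to IPEX's fused _IPEXRMSNorm. A hypothetical sketch of what a lowering_class_cpu-style swap amounts to (the real IPEX helper's constructor signature may differ):

    import torch.nn as nn

    def lower_class(model: nn.Module, old_cls: type, new_cls: type, config) -> None:
        # Hypothetical helper: replace each instance of old_cls with new_cls,
        # assuming new_cls(module, config) copies the weights over.
        for name, child in model.named_children():
            if isinstance(child, old_cls):
                setattr(model, name, new_cls(child, config))
            else:
                lower_class(child, old_cls, new_cls, config)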
@@ -91,16 +117,9 @@ def _ipex_optimize_decoder(model, decoder_layer):
         tpp=False,
         woq=False,
     )
 
-    convert_class(
-        model,
-        decoder_layer,
-        _IPEXDecoderLayerRef,
-        model.config,
-        distributed=True,
-    )
 
-def _ipex_optimize_attention(model, attention_layer):
+def _ipex_optimize_attention(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
         _IPEXAttentionRef
     )
@@ -116,13 +135,6 @@ def _ipex_optimize_attention(model, attention_layer):
         tpp=False,
         woq=False,
     )
 
-    convert_class(
-        model,
-        attention_layer,
-        _IPEXAttentionRef,
-        model.config,
-        distributed=True,
-    )
 
 def _ipex_jit(model):
@@ -178,7 +190,6 @@ def _make_causal_mask(
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
-
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
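After this change each helper resolves the classes it needs on its own, so callers only pass the model. A hedged end-to-end sketch of how the IPEX BF16 path is typically driven (the BIGDL_OPT_IPEX switch is an assumption about this release's gating; check the docs for your version):

    import os
    os.environ["BIGDL_OPT_IPEX"] = "true"  # assumed env switch enabling the IPEX path

    import torch
    from transformers import AutoModelForCausalLM
    from bigdl.llm import optimize_model

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
    )
    model = optimize_model(model)  # routes through _optimize_ipex when the switch is set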