LLM: Update optimize ipex bf16 (#10038)

* use 4.35.2 and remove

* update rmsnorm

* remove

* remove

* update python style

* update

* update python style

* update

* fix style

* update

* remove whitespace
Authored by Wang, Jian4 on 2024-01-31 10:59:55 +08:00, committed by GitHub
parent fb53b994f8, commit 7e5cd42a5c
2 changed files with 34 additions and 22 deletions


@@ -578,15 +578,16 @@ def _optimize_ipex(model):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
     from bigdl.llm.transformers.convert_ipex import (
         _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
-        _llama_model_forward_4_35
+        _ipex_optimize_rmsnorm, _llama_model_forward_4_35
     )
     AttentionMaskConverter._make_causal_mask = _make_causal_mask
-    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel, _llama_model_forward_4_35)  # noqa
+    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
+                    _llama_model_forward_4_35)
     model = model_convert_reference(model)
-    _ipex_optimize_attention(model, transformers.models.llama.modeling_llama.LlamaAttention)
-    _ipex_optimize_decoder(model, transformers.models.llama.modeling_llama.LlamaDecoderLayer)
+    _ipex_optimize_rmsnorm(model)
+    _ipex_optimize_attention(model)
+    _ipex_optimize_decoder(model)
     model.register_forward_hook(output_hook, with_kwargs=True)
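
Note: the convert_forward call above swaps the stock transformers LlamaModel.forward for the 4.35-compatible _llama_model_forward_4_35 on every matching submodule. A minimal sketch of that rebinding pattern, purely for illustration and not the actual bigdl.llm helper, could look like this:

import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_class: type, new_forward) -> None:
    # Walk the module tree and rebind `new_forward` as a bound method on every
    # instance of `target_class`, so the replacement still receives `self`.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)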


@@ -75,7 +75,33 @@ def _set_optimized_model_for_generation(
     return model


-def _ipex_optimize_decoder(model, decoder_layer):
+def _ipex_optimize_rmsnorm(_model):
+    from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
+    import transformers
+    supported_classes = [
+        transformers.models.llama.modeling_llama.LlamaRMSNorm,
+    ]
+    if _model.config.architectures[0] == "BaichuanForCausalLM":
+        supported_classes.append(type(_model.model.layers[0].input_layernorm))
+    if (
+        _model.config.architectures[0] == "ChatGLMModel"
+        and _model.config.rmsnorm
+    ):
+        supported_classes.append(
+            type(_model.transformer.encoder.layers[0].input_layernorm)
+        )
+    for supported_class in supported_classes:
+        lowering_class_cpu(
+            _model,
+            supported_class,
+            _IPEXRMSNorm,
+            _model.config,
+            tpp=False,
+            woq=False,
+        )
+
+
+def _ipex_optimize_decoder(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
         _IPEXDecoderLayerRef
     )
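
The new _ipex_optimize_rmsnorm collects the RMSNorm classes used by the supported architectures (LLaMA always, plus Baichuan and ChatGLM when detected from the model config) and passes each one to IPEX's lowering_class_cpu, which replaces every instance with the fused _IPEXRMSNorm. A minimal, purely illustrative sketch of that class-lowering idea follows; FusedRMSNormSketch and lower_rmsnorm_sketch are hypothetical names, not the IPEX implementation:

import torch
import torch.nn as nn

class FusedRMSNormSketch(nn.Module):
    # Hypothetical stand-in for a fused RMSNorm, initialized from the weights
    # of the module it replaces.
    def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(weight.detach().clone())
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)

def lower_rmsnorm_sketch(model: nn.Module, supported_classes: tuple) -> None:
    # Collect (parent, attribute name, child) triples first so the module tree
    # is not mutated while it is being traversed, then swap each match.
    targets = []
    for parent in model.modules():
        for name, child in parent.named_children():
            if isinstance(child, supported_classes):
                targets.append((parent, name, child))
    for parent, name, child in targets:
        eps = getattr(child, "variance_epsilon", 1e-6)
        setattr(parent, name, FusedRMSNormSketch(child.weight, eps))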
@@ -91,16 +117,9 @@ def _ipex_optimize_decoder(model, decoder_layer):
         tpp=False,
         woq=False,
     )
-    convert_class(
-        model,
-        decoder_layer,
-        _IPEXDecoderLayerRef,
-        model.config,
-        distributed=True,
-    )


-def _ipex_optimize_attention(model, attention_layer):
+def _ipex_optimize_attention(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
         _IPEXAttentionRef
     )
@@ -116,13 +135,6 @@ def _ipex_optimize_attention(model, attention_layer):
         tpp=False,
         woq=False,
     )
-    convert_class(
-        model,
-        attention_layer,
-        _IPEXAttentionRef,
-        model.config,
-        distributed=True,
-    )


 def _ipex_jit(model):
@@ -178,7 +190,6 @@ def _make_causal_mask(
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask