From 7e5cd42a5ca6734f788d656dfbf2220b6ba330e4 Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Wed, 31 Jan 2024 10:59:55 +0800
Subject: [PATCH] LLM : Update optimize ipex bf16 (#10038)

* use 4.35.2 and remove
* update rmsnorm
* remove
* remove
* update python style
* update
* update python style
* update
* fix style
* update
* remove whitespace
---
 .../llm/src/bigdl/llm/transformers/convert.py | 11 ++---
 .../bigdl/llm/transformers/convert_ipex.py    | 45 ++++++++++++-------
 2 files changed, 34 insertions(+), 22 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index b7614dd5..0b30b579 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -578,15 +578,16 @@ def _optimize_ipex(model):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
     from bigdl.llm.transformers.convert_ipex import (
         _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
-        _llama_model_forward_4_35
+        _ipex_optimize_rmsnorm, _llama_model_forward_4_35
     )
 
     AttentionMaskConverter._make_causal_mask = _make_causal_mask
-    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel, _llama_model_forward_4_35)  # noqa
+    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
+                    _llama_model_forward_4_35)
     model = model_convert_reference(model)
-
-    _ipex_optimize_attention(model, transformers.models.llama.modeling_llama.LlamaAttention)
-    _ipex_optimize_decoder(model, transformers.models.llama.modeling_llama.LlamaDecoderLayer)
+    _ipex_optimize_rmsnorm(model)
+    _ipex_optimize_attention(model)
+    _ipex_optimize_decoder(model)
 
     model.register_forward_hook(output_hook, with_kwargs=True)
 
diff --git a/python/llm/src/bigdl/llm/transformers/convert_ipex.py b/python/llm/src/bigdl/llm/transformers/convert_ipex.py
index b185aba0..cd534fbe 100644
--- a/python/llm/src/bigdl/llm/transformers/convert_ipex.py
+++ b/python/llm/src/bigdl/llm/transformers/convert_ipex.py
@@ -75,7 +75,33 @@ def _set_optimized_model_for_generation(
     return model
 
 
-def _ipex_optimize_decoder(model, decoder_layer):
+def _ipex_optimize_rmsnorm(_model):
+    from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
+    import transformers
+    supported_classes = [
+        transformers.models.llama.modeling_llama.LlamaRMSNorm,
+    ]
+    if _model.config.architectures[0] == "BaichuanForCausalLM":
+        supported_classes.append(type(_model.model.layers[0].input_layernorm))
+    if (
+        _model.config.architectures[0] == "ChatGLMModel"
+        and _model.config.rmsnorm
+    ):
+        supported_classes.append(
+            type(_model.transformer.encoder.layers[0].input_layernorm)
+        )
+    for supported_class in supported_classes:
+        lowering_class_cpu(
+            _model,
+            supported_class,
+            _IPEXRMSNorm,
+            _model.config,
+            tpp=False,
+            woq=False,
+        )
+
+
+def _ipex_optimize_decoder(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
         _IPEXDecoderLayerRef
     )
@@ -91,16 +117,9 @@ def _ipex_optimize_decoder(model, decoder_layer):
         tpp=False,
         woq=False,
     )
-    convert_class(
-        model,
-        decoder_layer,
-        _IPEXDecoderLayerRef,
-        model.config,
-        distributed=True,
-    )
 
 
-def _ipex_optimize_attention(model, attention_layer):
+def _ipex_optimize_attention(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
         _IPEXAttentionRef
     )
@@ -116,13 +135,6 @@ def _ipex_optimize_attention(model, attention_layer):
         tpp=False,
         woq=False,
     )
-    convert_class(
-        model,
-        attention_layer,
-        _IPEXAttentionRef,
-        model.config,
-        distributed=True,
-    )
 
 
 def _ipex_jit(model):
@@ -178,7 +190,6 @@ def _make_causal_mask(
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
 
-from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
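
Note on the pattern behind the patch: the new _ipex_optimize_rmsnorm helper collects the RMSNorm classes a model uses (LlamaRMSNorm, plus the Baichuan and ChatGLM variants when present) and passes each one to lowering_class_cpu so it can be lowered to IPEX's fused _IPEXRMSNorm. The sketch below illustrates that module-swapping idea in plain PyTorch only; it is not the BigDL or IPEX implementation, and SimpleRMSNorm, FusedRMSNorm and lower_rmsnorm are hypothetical names used purely for illustration.

# Hedged illustration only -- not the BigDL/IPEX code path. It shows the generic
# "replace every supported norm class with a fused equivalent" pattern that
# _ipex_optimize_rmsnorm delegates to lowering_class_cpu.
import torch
import torch.nn as nn


class SimpleRMSNorm(nn.Module):
    """Stand-in for a model's RMSNorm layer (e.g. LlamaRMSNorm)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.variance_epsilon)


class FusedRMSNorm(nn.Module):
    """Stand-in for a fused replacement such as _IPEXRMSNorm."""

    def __init__(self, source):
        super().__init__()
        self.weight = source.weight
        self.eps = source.variance_epsilon

    def forward(self, x):
        # A real fused kernel would compute this in a single op; this version is
        # only functionally equivalent.
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


def lower_rmsnorm(module, supported_classes=(SimpleRMSNorm,)):
    """Recursively swap every instance of a supported class for the fused version."""
    for name, child in module.named_children():
        if isinstance(child, tuple(supported_classes)):
            setattr(module, name, FusedRMSNorm(child))
        else:
            lower_rmsnorm(child, supported_classes)
    return module


# Usage sketch: after lowering, the SimpleRMSNorm inside the block is a FusedRMSNorm.
block = nn.Sequential(nn.Linear(16, 16), SimpleRMSNorm(16))
block = lower_rmsnorm(block)
print(block)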