LLM : Update optimize ipex bf16 (#10038)
* use 4.35.2 and remove
* update rmsnorm
* remove
* remove
* update python style
* update
* update python style
* update
* fix style
* update
* remove whitespace
parent fb53b994f8
commit 7e5cd42a5c
2 changed files with 34 additions and 22 deletions
@@ -578,15 +578,16 @@ def _optimize_ipex(model):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
     from bigdl.llm.transformers.convert_ipex import (
         _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
-        _llama_model_forward_4_35
+        _ipex_optimize_rmsnorm, _llama_model_forward_4_35
     )
 
     AttentionMaskConverter._make_causal_mask = _make_causal_mask
-    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel, _llama_model_forward_4_35)  # noqa
+    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
+                    _llama_model_forward_4_35)
     model = model_convert_reference(model)
-
-    _ipex_optimize_attention(model, transformers.models.llama.modeling_llama.LlamaAttention)
-    _ipex_optimize_decoder(model, transformers.models.llama.modeling_llama.LlamaDecoderLayer)
+    _ipex_optimize_rmsnorm(model)
+    _ipex_optimize_attention(model)
+    _ipex_optimize_decoder(model)
 
     model.register_forward_hook(output_hook, with_kwargs=True)
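For context on the helper used above: `convert_forward` is bigdl-llm's generic forward-patcher, here pointing `LlamaModel` at `_llama_model_forward_4_35`. A minimal sketch of that pattern, under the assumption that it simply rebinds `forward` on every matching submodule (the function name below is illustrative, not the repo's exact code):

```python
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_class: type, new_forward) -> None:
    """Recursively rebind `forward` on every instance of target_class."""
    for _, submodule in model.named_children():
        if isinstance(submodule, target_class):
            # __get__ binds the plain function to this instance, like a method.
            submodule.forward = new_forward.__get__(submodule, submodule.__class__)
        convert_forward_sketch(submodule, target_class, new_forward)
```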
@@ -75,7 +75,33 @@ def _set_optimized_model_for_generation(
     return model
 
 
-def _ipex_optimize_decoder(model, decoder_layer):
+def _ipex_optimize_rmsnorm(_model):
+    from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
+    import transformers
+    supported_classes = [
+        transformers.models.llama.modeling_llama.LlamaRMSNorm,
+    ]
+    if _model.config.architectures[0] == "BaichuanForCausalLM":
+        supported_classes.append(type(_model.model.layers[0].input_layernorm))
+    if (
+        _model.config.architectures[0] == "ChatGLMModel"
+        and _model.config.rmsnorm
+    ):
+        supported_classes.append(
+            type(_model.transformer.encoder.layers[0].input_layernorm)
+        )
+    for supported_class in supported_classes:
+        lowering_class_cpu(
+            _model,
+            supported_class,
+            _IPEXRMSNorm,
+            _model.config,
+            tpp=False,
+            woq=False,
+        )
+
+
+def _ipex_optimize_decoder(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
         _IPEXDecoderLayerRef
     )
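The new `_ipex_optimize_rmsnorm` collects the RMSNorm classes per architecture (Llama always; Baichuan and ChatGLM when detected from `config.architectures`) and hands each to IPEX's `lowering_class_cpu`, which swaps instances for the fused `_IPEXRMSNorm`. A minimal sketch of that class-swapping idea, with a hypothetical `FusedRMSNorm` standing in for the IPEX module and `lower_class` for the real helper (which also threads `config`/`tpp`/`woq` through):

```python
import torch
import torch.nn as nn

class FusedRMSNorm(nn.Module):
    """Hypothetical stand-in for _IPEXRMSNorm: wraps an existing RMSNorm."""
    def __init__(self, orig: nn.Module):
        super().__init__()
        self.weight = orig.weight  # reuse the original parameters
        self.eps = getattr(orig, "variance_epsilon", 1e-6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same math as HF LlamaRMSNorm; a real fused kernel does this in one op.
        var = x.float().pow(2).mean(-1, keepdim=True)
        return self.weight * (x.float() * torch.rsqrt(var + self.eps)).to(x.dtype)

def lower_class(model: nn.Module, source_class: type, fused_class: type) -> None:
    """Replace every instance of source_class with fused_class built from it."""
    for name, child in model.named_children():
        if isinstance(child, source_class):
            setattr(model, name, fused_class(child))
        else:
            lower_class(child, source_class, fused_class)
```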
@@ -91,16 +117,9 @@ def _ipex_optimize_decoder(model, decoder_layer):
         tpp=False,
         woq=False,
     )
-    convert_class(
-        model,
-        decoder_layer,
-        _IPEXDecoderLayerRef,
-        model.config,
-        distributed=True,
-    )
 
 
-def _ipex_optimize_attention(model, attention_layer):
+def _ipex_optimize_attention(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
         _IPEXAttentionRef
     )
@@ -116,13 +135,6 @@ def _ipex_optimize_attention(model, attention_layer):
         tpp=False,
         woq=False,
     )
-    convert_class(
-        model,
-        attention_layer,
-        _IPEXAttentionRef,
-        model.config,
-        distributed=True,
-    )
 
 
 def _ipex_jit(model):
@@ -178,7 +190,6 @@ def _make_causal_mask(
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
 
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 
 
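The patched `_make_causal_mask` ends in the return statement above, which fixes the mask shape at `(bsz, 1, tgt_len, tgt_len + past_key_values_length)`. A minimal sketch consistent with that shape (dtype/device handling in the real function is elided, and `float("-inf")` stands in for whatever fill value it uses):

```python
import torch

def make_causal_mask_sketch(bsz: int, tgt_len: int, past_key_values_length: int = 0):
    """Standard causal mask: -inf above the diagonal, 0 on/below and over the past."""
    mask = torch.full((tgt_len, tgt_len), float("-inf"))
    mask = torch.triu(mask, diagonal=1)  # future positions masked
    if past_key_values_length > 0:
        past = torch.zeros(tgt_len, past_key_values_length)  # past tokens stay visible
        mask = torch.cat([past, mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
```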