LLM : Update optimize ipex bf16 (#10038)
* use 4.35.2 and remove
* update rmsnorm
* remove
* remove
* update python style
* update
* update python style
* update
* fix style
* update
* remove whitespace
parent fb53b994f8
commit 7e5cd42a5c

2 changed files with 34 additions and 22 deletions
@@ -578,15 +578,16 @@ def _optimize_ipex(model):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
     from bigdl.llm.transformers.convert_ipex import (
         _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
-        _llama_model_forward_4_35
+        _ipex_optimize_rmsnorm, _llama_model_forward_4_35
     )
 
     AttentionMaskConverter._make_causal_mask = _make_causal_mask
-    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel, _llama_model_forward_4_35)  # noqa
+    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
+                    _llama_model_forward_4_35)
     model = model_convert_reference(model)
-
-    _ipex_optimize_attention(model, transformers.models.llama.modeling_llama.LlamaAttention)
-    _ipex_optimize_decoder(model, transformers.models.llama.modeling_llama.LlamaDecoderLayer)
+    _ipex_optimize_rmsnorm(model)
+    _ipex_optimize_attention(model)
+    _ipex_optimize_decoder(model)
 
     model.register_forward_hook(output_hook, with_kwargs=True)
 
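Note: `convert_forward` is bigdl-llm's helper for rebinding the `forward` method on every instance of a target module class, which is how `_llama_model_forward_4_35` replaces `LlamaModel.forward` here. A minimal sketch of that pattern (illustrative only, not the library's actual implementation):

# Sketch of the convert_forward pattern: walk the module tree and rebind
# `forward` on every instance of the target class. Illustrative, not
# bigdl-llm's actual code.
import types

import torch.nn as nn


def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
    for module in model.modules():
        if isinstance(module, target_cls):
            # Bind as an instance method so `self` is the module itself.
            module.forward = types.MethodType(new_forward, module)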
@@ -75,7 +75,33 @@ def _set_optimized_model_for_generation(
     return model
 
 
-def _ipex_optimize_decoder(model, decoder_layer):
+def _ipex_optimize_rmsnorm(_model):
+    from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
+    import transformers
+    supported_classes = [
+        transformers.models.llama.modeling_llama.LlamaRMSNorm,
+    ]
+    if _model.config.architectures[0] == "BaichuanForCausalLM":
+        supported_classes.append(type(_model.model.layers[0].input_layernorm))
+    if (
+        _model.config.architectures[0] == "ChatGLMModel"
+        and _model.config.rmsnorm
+    ):
+        supported_classes.append(
+            type(_model.transformer.encoder.layers[0].input_layernorm)
+        )
+    for supported_class in supported_classes:
+        lowering_class_cpu(
+            _model,
+            supported_class,
+            _IPEXRMSNorm,
+            _model.config,
+            tpp=False,
+            woq=False,
+        )
+
+
+def _ipex_optimize_decoder(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
         _IPEXDecoderLayerRef
     )
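The new `_ipex_optimize_rmsnorm` builds its class list per architecture: `LlamaRMSNorm` is importable from transformers, while the Baichuan and ChatGLM norm classes live in remote code, so they are recovered with `type(...)` from a live layer instance. For illustration, the `architectures[0]` string it dispatches on comes straight from the model config:

# Illustration: config.architectures[0] is the model class name string that
# the checks above compare against (e.g. "BaichuanForCausalLM", "ChatGLMModel").
from transformers import LlamaConfig

cfg = LlamaConfig(architectures=["LlamaForCausalLM"])
print(cfg.architectures[0])  # -> "LlamaForCausalLM"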
@@ -91,16 +117,9 @@ def _ipex_optimize_decoder(model, decoder_layer):
             tpp=False,
             woq=False,
         )
-    convert_class(
-        model,
-        decoder_layer,
-        _IPEXDecoderLayerRef,
-        model.config,
-        distributed=True,
-    )
 
 
-def _ipex_optimize_attention(model, attention_layer):
+def _ipex_optimize_attention(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
         _IPEXAttentionRef
     )
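This hunk and the next both drop the explicit `convert_class(...)` call, which is why `_ipex_optimize_decoder` and `_ipex_optimize_attention` no longer take a layer-class argument. For reference, a convert_class-style swap follows a recursive pattern along these lines (a sketch with an illustrative signature, not IPEX's actual helper):

# Hypothetical sketch of a convert_class-style swap: replace every instance
# of target_cls with new_cls, assumed here to wrap the original module.
import torch.nn as nn


def convert_class_sketch(model: nn.Module, target_cls: type, new_cls: type, config) -> None:
    for name, child in model.named_children():
        if isinstance(child, target_cls):
            setattr(model, name, new_cls(child, config))
        else:
            convert_class_sketch(child, target_cls, new_cls, config)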
@@ -116,13 +135,6 @@ def _ipex_optimize_attention(model, attention_layer):
             tpp=False,
             woq=False,
         )
-    convert_class(
-        model,
-        attention_layer,
-        _IPEXAttentionRef,
-        model.config,
-        distributed=True,
-    )
 
 
 def _ipex_jit(model):
@@ -178,7 +190,6 @@ def _make_causal_mask(
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
-
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 
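`_prepare_4d_causal_attention_mask` is the transformers 4.35 utility that builds the `(bsz, 1, tgt_len, src_len)` causal mask consumed by `LlamaModel.forward`, which is presumably why the module imports it for `_llama_model_forward_4_35`. A small usage example, assuming the 4.35 signature:

# Usage example for the transformers 4.35 mask helper imported above.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

bsz, seq_len, hidden = 1, 8, 32
inputs_embeds = torch.zeros(bsz, seq_len, hidden)
attention_mask = torch.ones(bsz, seq_len, dtype=torch.long)

mask = _prepare_4d_causal_attention_mask(
    attention_mask, (bsz, seq_len), inputs_embeds, past_key_values_length=0
)
print(mask.shape)  # torch.Size([1, 1, 8, 8])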