LLM : Update optimize ipex bf16 (#10038)
* use 4.35.2 and remove
* update rmsnorm
* remove
* remove
* update python style
* update
* update python style
* update
* fix style
* update
* remove whitespace
parent fb53b994f8
commit 7e5cd42a5c

2 changed files with 34 additions and 22 deletions
@@ -578,15 +578,16 @@ def _optimize_ipex(model):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
     from bigdl.llm.transformers.convert_ipex import (
         _ipex_optimize_attention, _ipex_optimize_decoder, _ipex_jit, _make_causal_mask,
-        _llama_model_forward_4_35
+        _ipex_optimize_rmsnorm, _llama_model_forward_4_35
     )
 
     AttentionMaskConverter._make_causal_mask = _make_causal_mask
-    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel, _llama_model_forward_4_35)  # noqa
+    convert_forward(model, transformers.models.llama.modeling_llama.LlamaModel,
+                    _llama_model_forward_4_35)
     model = model_convert_reference(model)
-    _ipex_optimize_attention(model, transformers.models.llama.modeling_llama.LlamaAttention)
-    _ipex_optimize_decoder(model, transformers.models.llama.modeling_llama.LlamaDecoderLayer)
+    _ipex_optimize_rmsnorm(model)
+    _ipex_optimize_attention(model)
+    _ipex_optimize_decoder(model)
 
     model.register_forward_hook(output_hook, with_kwargs=True)
 
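For reference, `convert_forward` (used above to route `LlamaModel.forward` to `_llama_model_forward_4_35`) conceptually amounts to class-level forward patching: walk the module tree and rebind `forward` on every instance of the target class. The sketch below is a minimal, self-contained illustration of that pattern in plain PyTorch; `patch_forward` and `traced_forward` are illustrative names, not BigDL's implementation.

# Minimal sketch of class-level forward patching (illustrative, not BigDL's convert_forward)
import types
import torch
from torch import nn

def patch_forward(model: nn.Module, target_class, new_forward):
    # Rebind new_forward onto every sub-module whose type matches target_class.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)

def traced_forward(self, x):
    # Toy replacement forward: same math as nn.Linear, plus a marker flag.
    self.was_patched = True
    return torch.nn.functional.linear(x, self.weight, self.bias)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
patch_forward(model, nn.Linear, traced_forward)
print(model(torch.randn(1, 4)).shape)  # torch.Size([1, 2])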
@@ -75,7 +75,33 @@ def _set_optimized_model_for_generation(
     return model
 
 
-def _ipex_optimize_decoder(model, decoder_layer):
+def _ipex_optimize_rmsnorm(_model):
+    from intel_extension_for_pytorch.transformers.models.cpu.fusions.mha_fusion import _IPEXRMSNorm
+    import transformers
+    supported_classes = [
+        transformers.models.llama.modeling_llama.LlamaRMSNorm,
+    ]
+    if _model.config.architectures[0] == "BaichuanForCausalLM":
+        supported_classes.append(type(_model.model.layers[0].input_layernorm))
+    if (
+        _model.config.architectures[0] == "ChatGLMModel"
+        and _model.config.rmsnorm
+    ):
+        supported_classes.append(
+            type(_model.transformer.encoder.layers[0].input_layernorm)
+        )
+    for supported_class in supported_classes:
+        lowering_class_cpu(
+            _model,
+            supported_class,
+            _IPEXRMSNorm,
+            _model.config,
+            tpp=False,
+            woq=False,
+        )
+
+
+def _ipex_optimize_decoder(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.decoder import (
         _IPEXDecoderLayerRef
     )
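The new `_ipex_optimize_rmsnorm` collects the RMSNorm classes present in the model (LLaMA's `LlamaRMSNorm` by default, plus the Baichuan or ChatGLM variants detected from `model.config`) and lowers each match to IPEX's fused `_IPEXRMSNorm` via `lowering_class_cpu`. Below is a minimal, self-contained sketch of that class-lowering pattern in plain PyTorch; `ToyRMSNorm`, `FusedRMSNorm`, and `swap_class` are illustrative stand-ins, not the IPEX API.

import torch
from torch import nn

class ToyRMSNorm(nn.Module):
    # Stand-in for a model's RMSNorm layer (e.g. LlamaRMSNorm).
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        var = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(var + self.eps)

class FusedRMSNorm(nn.Module):
    # Hypothetical fused replacement; a real lowering would dispatch to an optimized kernel.
    def __init__(self, weight, eps):
        super().__init__()
        self.weight = nn.Parameter(weight.clone())
        self.eps = eps

    def forward(self, x):
        var = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(var + self.eps)

def swap_class(model, old_class, build_new):
    # Replace every child module of type old_class with one built from it.
    for name, child in model.named_children():
        if isinstance(child, old_class):
            setattr(model, name, build_new(child))
        else:
            swap_class(child, old_class, build_new)

model = nn.Sequential(nn.Linear(8, 8), ToyRMSNorm(8))
swap_class(model, ToyRMSNorm, lambda m: FusedRMSNorm(m.weight.data, m.eps))
print(model)  # the ToyRMSNorm child is now a FusedRMSNorm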
@@ -91,16 +117,9 @@ def _ipex_optimize_decoder(model, decoder_layer):
             tpp=False,
             woq=False,
         )
-    convert_class(
-        model,
-        decoder_layer,
-        _IPEXDecoderLayerRef,
-        model.config,
-        distributed=True,
-    )
 
 
-def _ipex_optimize_attention(model, attention_layer):
+def _ipex_optimize_attention(model):
     from intel_extension_for_pytorch.transformers.models.reference.modules.attentions import (
         _IPEXAttentionRef
     )
@@ -116,13 +135,6 @@ def _ipex_optimize_attention(model, attention_layer):
             tpp=False,
             woq=False,
         )
-    convert_class(
-        model,
-        attention_layer,
-        _IPEXAttentionRef,
-        model.config,
-        distributed=True,
-    )
 
 
 def _ipex_jit(model):
@@ -178,7 +190,6 @@ def _make_causal_mask(
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
-
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 