LLM: Reduce speculative _ipex_optimize_model memory use (#10281)
* use tpp
* update ipex
This commit is contained in:
		
							parent
							
								
									f0ff0eebe1
								
							
						
					
					
						commit
						beb9433cec
					
				
					 2 changed files with 9 additions and 5 deletions
				
			
		| 
						 | 
				
			
			@ -662,7 +662,6 @@ def replace_func(m, target_m, func_name, new_func):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def _optimize_ipex(model):
 | 
			
		||||
    import intel_extension_for_pytorch as ipex
 | 
			
		||||
    from intel_extension_for_pytorch.transformers.optimize import model_convert_reference
 | 
			
		||||
    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 | 
			
		||||
    from bigdl.llm.transformers.convert_ipex import (
 | 
			
		||||
| 
						 | 
				
			
			@ -694,7 +693,6 @@ def _optimize_ipex(model):
 | 
			
		|||
        # baichuan2
 | 
			
		||||
        rms_classes.append(type(model.model.layers[0].input_layernorm))
 | 
			
		||||
 | 
			
		||||
    model = ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
 | 
			
		||||
    _ipex_optimize_model(model, rms_classes)
 | 
			
		||||
    return _ipex_jit(model)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -41,6 +41,10 @@ from intel_extension_for_pytorch.transformers.optimize import (
 | 
			
		|||
    lowering_class_cpu,
 | 
			
		||||
    convert_class,
 | 
			
		||||
)
 | 
			
		||||
from intel_extension_for_pytorch.cpu._auto_kernel_selection import (
 | 
			
		||||
    _enable_tpp,
 | 
			
		||||
    _using_tpp,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _ipex_optimize_rmsnorm(_model, supported_classes):
 | 
			
		||||
| 
						 | 
				
			
			@ -69,7 +73,7 @@ def _ipex_optimize_decoder(model):
 | 
			
		|||
            supported_mlp_class,
 | 
			
		||||
            _IPEXDecoderLayerCPU,
 | 
			
		||||
            model.config,
 | 
			
		||||
            tpp=False,
 | 
			
		||||
            tpp=True if _using_tpp() else False,
 | 
			
		||||
            woq=False,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -87,13 +91,15 @@ def _ipex_optimize_attention(model):
 | 
			
		|||
            supported_mha_class,
 | 
			
		||||
            _IPEXAttentionCPU,
 | 
			
		||||
            model.config,
 | 
			
		||||
            tpp=False,
 | 
			
		||||
            tpp=True if _using_tpp() else False,
 | 
			
		||||
            woq=False,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _ipex_optimize_model(model, rms_classes):
 | 
			
		||||
 | 
			
		||||
    _enable_tpp()
 | 
			
		||||
    import intel_extension_for_pytorch as ipex
 | 
			
		||||
    ipex.optimize(model.eval(), dtype=torch.bfloat16, inplace=True).eval()
 | 
			
		||||
    _ipex_optimize_rmsnorm(model, rms_classes)
 | 
			
		||||
    _ipex_optimize_attention(model)
 | 
			
		||||
    _ipex_optimize_decoder(model)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue