Enable optimization for sentence_transformers (#10679)
* enable optimization for sentence_transformers
* fix python style check failure
parent f03c029914
commit dcb2038aad

1 changed file with 28 additions and 0 deletions
@@ -527,6 +527,16 @@ def replace_with_low_bit_linear_for_module(model, qtype, module_name=None,
 def _optimize_pre(model):
+    try:
+        from sentence_transformers.SentenceTransformer import SentenceTransformer
+        if isinstance(model, SentenceTransformer):
+            if str(model._modules['0']).strip().split(' ')[-1] == 'BertModel':
+                from ipex_llm.transformers.models.bert import merge_qkv
+                model.apply(merge_qkv)
+                return model
+    except ModuleNotFoundError:
+        pass
+
     from transformers.modeling_utils import PreTrainedModel
     # All huggingface format models are inherited from `PreTrainedModel`
     if not isinstance(model, PreTrainedModel):
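The new check in `_optimize_pre` leans on the first module of a SentenceTransformer being its Transformer wrapper, whose string representation ends with the backbone class name. A minimal sketch of how that detection behaves outside the library (the model name is illustrative and not part of this commit):

# Sketch: the BertModel backbone check used above (illustrative model name).
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")   # any BERT-based embedding model
first_module = model._modules['0']                # sentence_transformers Transformer wrapper
print(str(first_module).strip())                  # e.g. "Transformer({...}) with Transformer model: BertModel"
print(str(first_module).strip().split(' ')[-1])   # "BertModel" -> merge_qkv gets applied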
@@ -793,6 +803,24 @@ def _optimize_post(model, lightweight_bmm=False):
     from ipex_llm.transformers.models.llama import llama_model_forward
     from transformers.modeling_utils import PreTrainedModel

+    try:
+        from sentence_transformers.SentenceTransformer import SentenceTransformer
+        if isinstance(model, SentenceTransformer):
+            if str(model._modules['0']).strip().split(' ')[-1] == 'BertModel':
+                modeling_module_name = model._modules['0'].auto_model.__class__.__module__
+                module = importlib.import_module(modeling_module_name)
+                from ipex_llm.transformers.models.bert import self_attention_forward
+                from ipex_llm.transformers.models.bert import encoder_forward
+                convert_forward(model,
+                                module.BertSelfAttention,
+                                self_attention_forward)
+                convert_forward(model,
+                                module.BertEncoder,
+                                encoder_forward)
+                return model
+    except ModuleNotFoundError:
+        pass
+
     # All huggingface format models are inherited from `PreTrainedModel`
     if not isinstance(model, PreTrainedModel):
         logger.info("Only HuggingFace Transformers models are currently "
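End to end, both hooks are reached through ipex_llm's optimization entry point. A minimal usage sketch, assuming the public optimize_model API dispatches into _optimize_pre/_optimize_post as it does for other supported model types (the model name is illustrative, not taken from this commit):

# Usage sketch: a SentenceTransformer run through ipex_llm optimization
# (assumes optimize_model routes into _optimize_pre/_optimize_post;
# the model name below is illustrative).
from sentence_transformers import SentenceTransformer
from ipex_llm import optimize_model

model = SentenceTransformer("all-MiniLM-L6-v2")
model = optimize_model(model)   # merge_qkv plus the fused BERT forwards apply here
embeddings = model.encode(["IPEX-LLM accelerates embedding models too."])
print(embeddings.shape)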