From dcb2038aad123e3c7a852edf6f0f74489d597217 Mon Sep 17 00:00:00 2001 From: Ovo233 <76120304+Mingyu-Wei@users.noreply.github.com> Date: Tue, 9 Apr 2024 12:33:46 +0800 Subject: [PATCH] Enable optimization for sentence_transformers (#10679) * enable optimization for sentence_transformers * fix python style check failure --- .../llm/src/ipex_llm/transformers/convert.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index e8c24db1..9372f667 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -527,6 +527,16 @@ def replace_with_low_bit_linear_for_module(model, qtype, module_name=None, def _optimize_pre(model): + try: + from sentence_transformers.SentenceTransformer import SentenceTransformer + if isinstance(model, SentenceTransformer): + if str(model._modules['0']).strip().split(' ')[-1] == 'BertModel': + from ipex_llm.transformers.models.bert import merge_qkv + model.apply(merge_qkv) + return model + except ModuleNotFoundError: + pass + from transformers.modeling_utils import PreTrainedModel # All huggingface format models are inherited from `PreTrainedModel` if not isinstance(model, PreTrainedModel): @@ -793,6 +803,24 @@ def _optimize_post(model, lightweight_bmm=False): from ipex_llm.transformers.models.llama import llama_model_forward from transformers.modeling_utils import PreTrainedModel + try: + from sentence_transformers.SentenceTransformer import SentenceTransformer + if isinstance(model, SentenceTransformer): + if str(model._modules['0']).strip().split(' ')[-1] == 'BertModel': + modeling_module_name = model._modules['0'].auto_model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from ipex_llm.transformers.models.bert import self_attention_forward + from ipex_llm.transformers.models.bert import encoder_forward + convert_forward(model, + module.BertSelfAttention, + self_attention_forward) + convert_forward(model, + module.BertEncoder, + encoder_forward) + return model + except ModuleNotFoundError: + pass + # All huggingface format models are inherited from `PreTrainedModel` if not isinstance(model, PreTrainedModel): logger.info("Only HuggingFace Transformers models are currently "