diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index e8c24db1..9372f667 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -527,6 +527,16 @@ def replace_with_low_bit_linear_for_module(model, qtype, module_name=None,
 
 
 def _optimize_pre(model):
+    try:
+        from sentence_transformers.SentenceTransformer import SentenceTransformer
+        if isinstance(model, SentenceTransformer):
+            if str(model._modules['0']).strip().split(' ')[-1] == 'BertModel':
+                from ipex_llm.transformers.models.bert import merge_qkv
+                model.apply(merge_qkv)
+                return model
+    except ModuleNotFoundError:
+        pass
+
     from transformers.modeling_utils import PreTrainedModel
     # All huggingface format models are inherited from `PreTrainedModel`
     if not isinstance(model, PreTrainedModel):
@@ -793,6 +803,24 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.llama import llama_model_forward
     from transformers.modeling_utils import PreTrainedModel
 
+    try:
+        from sentence_transformers.SentenceTransformer import SentenceTransformer
+        if isinstance(model, SentenceTransformer):
+            if str(model._modules['0']).strip().split(' ')[-1] == 'BertModel':
+                modeling_module_name = model._modules['0'].auto_model.__class__.__module__
+                module = importlib.import_module(modeling_module_name)
+                from ipex_llm.transformers.models.bert import self_attention_forward
+                from ipex_llm.transformers.models.bert import encoder_forward
+                convert_forward(model,
+                                module.BertSelfAttention,
+                                self_attention_forward)
+                convert_forward(model,
+                                module.BertEncoder,
+                                encoder_forward)
+                return model
+    except ModuleNotFoundError:
+        pass
+
     # All huggingface format models are inherited from `PreTrainedModel`
     if not isinstance(model, PreTrainedModel):
         logger.info("Only HuggingFace Transformers models are currently "
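
For context on the `_optimize_pre` branch: `merge_qkv` is applied through `model.apply`, which visits every submodule, so the helper has to pick out the attention blocks itself. Below is a minimal sketch of the fusion idea, assuming the standard BertSelfAttention layout with separate query/key/value Linear projections; it is not the actual ipex_llm.transformers.models.bert implementation, and the name merge_qkv_sketch and the duck-typed attribute check are illustrative.

# Sketch only, NOT the ipex_llm implementation: fuse BertSelfAttention's three
# projections into one nn.Linear so a single GEMM replaces three.
import torch
import torch.nn as nn

def merge_qkv_sketch(module):
    # model.apply() visits every submodule, so filter to modules that look like
    # BertSelfAttention: separate query/key/value nn.Linear projections.
    names = ("query", "key", "value")
    if all(isinstance(getattr(module, n, None), nn.Linear) for n in names):
        q, k, v = (getattr(module, n) for n in names)
        # One fused projection whose output is [Q | K | V] concatenated.
        qkv = nn.Linear(q.in_features,
                        q.out_features + k.out_features + v.out_features,
                        bias=q.bias is not None)
        qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
        if q.bias is not None:
            qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
        module.qkv = qkv
        for n in names:
            delattr(module, n)

Once the three projections are fused, the stock BertSelfAttention.forward would no longer find query/key/value, which is why _optimize_post swaps in self_attention_forward: the pre and post steps are a matched pair.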
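
In the `_optimize_post` branch, `convert_forward` is the helper this file already uses for other architectures. The sketch below shows only the forward-rebinding pattern it embodies, assuming a plain walk over the module tree; the function name and body are illustrative, not the ipex_llm helper itself.

# Hypothetical sketch of the forward-replacement pattern behind convert_forward.
import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
    # Rebind `forward` on every instance of target_cls found in the model.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)

Resolving the target classes through model._modules['0'].auto_model.__class__.__module__ and importlib means the patch binds to whichever transformers modeling module actually built the model, rather than relying on a hard-coded import.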
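
End to end, the change lets a BERT-backed embedding model go through the library's usual optimization entry point. A usage sketch, assuming ipex_llm.optimize_model routes through _optimize_pre/_optimize_post and with an illustrative checkpoint name:

# Usage sketch: any BERT-backed sentence-transformers checkpoint should take
# the new SentenceTransformer branch; the model name here is illustrative.
from sentence_transformers import SentenceTransformer
from ipex_llm import optimize_model

model = SentenceTransformer("all-MiniLM-L6-v2")  # module '0' wraps a BertModel
model = optimize_model(model)  # triggers the merge_qkv / convert_forward path

embeddings = model.encode(["ipex-llm can optimize BERT embedding models."])
print(embeddings.shape)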