Enable optimization for sentence_transformers (#10679)
* enable optimization for sentence_transformers
* fix python style check failure
This commit is contained in:
parent
f03c029914
commit
dcb2038aad
1 changed file with 28 additions and 0 deletions
|
|
@@ -527,6 +527,16 @@ def replace_with_low_bit_linear_for_module(model, qtype, module_name=None,
|
|||
|
||||
|
||||
def _optimize_pre(model):
|
||||
try:
|
||||
from sentence_transformers.SentenceTransformer import SentenceTransformer
|
||||
if isinstance(model, SentenceTransformer):
|
||||
if str(model._modules['0']).strip().split(' ')[-1] == 'BertModel':
|
||||
from ipex_llm.transformers.models.bert import merge_qkv
|
||||
model.apply(merge_qkv)
|
||||
return model
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
# All huggingface format models are inherited from `PreTrainedModel`
|
||||
if not isinstance(model, PreTrainedModel):
|
||||
|
|
@@ -793,6 +803,24 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
from ipex_llm.transformers.models.llama import llama_model_forward
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
try:
|
||||
from sentence_transformers.SentenceTransformer import SentenceTransformer
|
||||
if isinstance(model, SentenceTransformer):
|
||||
if str(model._modules['0']).strip().split(' ')[-1] == 'BertModel':
|
||||
modeling_module_name = model._modules['0'].auto_model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.bert import self_attention_forward
|
||||
from ipex_llm.transformers.models.bert import encoder_forward
|
||||
convert_forward(model,
|
||||
module.BertSelfAttention,
|
||||
self_attention_forward)
|
||||
convert_forward(model,
|
||||
module.BertEncoder,
|
||||
encoder_forward)
|
||||
return model
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
# All huggingface format models are inherited from `PreTrainedModel`
|
||||
if not isinstance(model, PreTrainedModel):
|
||||
logger.info("Only HuggingFace Transformers models are currently "
|
||||
|
|
|
|||
Loading…
Reference in a new issue