[LLM] Add TransformersBgeEmbeddings class in bigdl.llm.langchain.embeddings (#10459)

* Add TransformersBgeEmbeddings class in bigdl.llm.langchain.embeddings

* Small fixes
Author: Yuwen Hu, 2024-03-19 18:04:35 +08:00 (committed by GitHub)
parent 463a86cd5d
commit 72bcc27da9
2 changed files with 15 additions and 2 deletions

File: bigdl/llm/langchain/embeddings/__init__.py

@@ -20,7 +20,7 @@
 # only search the first bigdl package and end up finding only one sub-package.
 from .bigdlllm import *
-from .transformersembeddings import TransformersEmbeddings
+from .transformersembeddings import TransformersEmbeddings, TransformersBgeEmbeddings

 __all__ = [
     "BigdlNativeEmbeddings",
@@ -28,5 +28,6 @@ __all__ = [
     "BloomEmbeddings",
     "GptneoxEmbeddings",
     "StarcoderEmbeddings",
-    "TransformersEmbeddings"
+    "TransformersEmbeddings",
+    "TransformersBgeEmbeddings"
 ]

File: bigdl/llm/langchain/embeddings/transformersembeddings.py

@@ -45,6 +45,7 @@
 # THE SOFTWARE.

 """Wrapper around BigdlLLM embedding models."""
+import torch
 from typing import Any, Dict, List, Optional

 import numpy as np
@@ -181,3 +182,14 @@ class TransformersEmbeddings(BaseModel, Embeddings):
         text = text.replace("\n", " ")
         embedding = self.embed(text, **self.encode_kwargs)
         return embedding.tolist()
+
+
+# fit specific encode method for langchain.embeddings.HuggingFaceBgeEmbeddings
+# TODO: directly support HuggingFaceBgeEmbeddings
+class TransformersBgeEmbeddings(TransformersEmbeddings):
+    def embed(self, text: str, **kwargs):
+        input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs)
+        input_ids = input_ids.to(self.model.device)
+        embeddings = self.model(input_ids, return_dict=False)[0].cpu()
+        embeddings = torch.nn.functional.normalize(embeddings[:, 0], p=2, dim=1)
+        return embeddings[0]
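
For context, a minimal usage sketch (not part of the commit). It assumes the from_model_id constructor inherited from TransformersEmbeddings; the model id is only an illustrative BGE checkpoint.

# Hypothetical usage sketch, assuming the inherited from_model_id constructor;
# "BAAI/bge-small-en-v1.5" is an example model id, not taken from this commit.
from bigdl.llm.langchain.embeddings import TransformersBgeEmbeddings

embeddings = TransformersBgeEmbeddings.from_model_id(model_id="BAAI/bge-small-en-v1.5")

# embed_query()/embed_documents() come from the parent TransformersEmbeddings and
# delegate to the overridden embed(): CLS-token pooling followed by L2 normalization,
# mirroring langchain's HuggingFaceBgeEmbeddings behaviour.
query_vec = embeddings.embed_query("What is a BGE embedding model?")
doc_vecs = embeddings.embed_documents(["BGE maps sentences to dense vectors."])
print(len(query_vec), len(doc_vecs[0]))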