[LLM] Add TransformersBgeEmbeddings class in bigdl.llm.langchain.embeddings (#10459)

* Add TransformersBgeEmbeddings class in bigdl.llm.langchain.embeddings

* Small fixes
Yuwen Hu 2024-03-19 18:04:35 +08:00 committed by GitHub
parent 463a86cd5d
commit 72bcc27da9
2 changed files with 15 additions and 2 deletions

bigdl/llm/langchain/embeddings/__init__.py  View file

@@ -20,7 +20,7 @@
 # only search the first bigdl package and end up finding only one sub-package.
 from .bigdlllm import *
-from .transformersembeddings import TransformersEmbeddings
+from .transformersembeddings import TransformersEmbeddings, TransformersBgeEmbeddings
 
 __all__ = [
     "BigdlNativeEmbeddings",
@@ -28,5 +28,6 @@ __all__ = [
     "BloomEmbeddings",
     "GptneoxEmbeddings",
     "StarcoderEmbeddings",
-    "TransformersEmbeddings"
+    "TransformersEmbeddings",
+    "TransformersBgeEmbeddings"
 ]
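
With the updated export above, the new class becomes importable directly from the package namespace; a minimal sketch using only the names touched by this commit:

from bigdl.llm.langchain.embeddings import TransformersEmbeddings, TransformersBgeEmbeddings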

bigdl/llm/langchain/embeddings/transformersembeddings.py  View file

@@ -45,6 +45,7 @@
 # THE SOFTWARE.
 """Wrapper around BigdlLLM embedding models."""
+import torch
 from typing import Any, Dict, List, Optional
 import numpy as np
@@ -181,3 +182,14 @@ class TransformersEmbeddings(BaseModel, Embeddings):
         text = text.replace("\n", " ")
         embedding = self.embed(text, **self.encode_kwargs)
         return embedding.tolist()
+
+# fit specific encode method for langchain.embeddings.HuggingFaceBgeEmbeddings
+# TODO: directly support HuggingFaceBgeEmbeddings
+class TransformersBgeEmbeddings(TransformersEmbeddings):
+
+    def embed(self, text: str, **kwargs):
+        input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs)
+        input_ids = input_ids.to(self.model.device)
+        embeddings = self.model(input_ids, return_dict=False)[0].cpu()
+        embeddings = torch.nn.functional.normalize(embeddings[:, 0], p=2, dim=1)
+        return embeddings[0]
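
For context: BGE-style models derive a sentence embedding from the hidden state of the first ([CLS]) token, L2-normalized, which is what the overridden embed() above computes before the base class's embed_query/embed_documents convert it to a list. A minimal usage sketch, assuming the from_model_id constructor and the embed_query/embed_documents methods are inherited unchanged from TransformersEmbeddings (the model id and texts below are illustrative):

from bigdl.llm.langchain.embeddings import TransformersBgeEmbeddings

# Load a BGE-style embedding model (constructor assumed to be inherited from
# TransformersEmbeddings; the model id is illustrative).
embeddings = TransformersBgeEmbeddings.from_model_id(model_id="BAAI/bge-small-en-v1.5")

# embed_query replaces newlines and forwards encode_kwargs to embed(); in this
# subclass those kwargs are passed straight to tokenizer.encode().
query_vec = embeddings.embed_query("What does BigDL-LLM provide?")
doc_vecs = embeddings.embed_documents(["BigDL-LLM accelerates LLM inference on Intel hardware."])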