[LLM] Add TransformersBgeEmbeddings class in bigdl.llm.langchain.embeddings (#10459)
* Add TransformersBgeEmbeddings class in bigdl.llm.langchain.embeddings
* Small fixes
parent 463a86cd5d
commit 72bcc27da9

2 changed files with 15 additions and 2 deletions
@@ -20,7 +20,7 @@
 # only search the first bigdl package and end up finding only one sub-package.
 from .bigdlllm import *
-from .transformersembeddings import TransformersEmbeddings
+from .transformersembeddings import TransformersEmbeddings, TransformersBgeEmbeddings
 
 __all__ = [
     "BigdlNativeEmbeddings",
@@ -28,5 +28,6 @@ __all__ = [
     "BloomEmbeddings",
     "GptneoxEmbeddings",
     "StarcoderEmbeddings",
-    "TransformersEmbeddings"
+    "TransformersEmbeddings",
+    "TransformersBgeEmbeddings"
 ]
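With the export in place, TransformersBgeEmbeddings can be imported next to TransformersEmbeddings. A minimal usage sketch, assuming the from_model_id constructor and embed_query method inherited from TransformersEmbeddings (the BGE model id is illustrative, not taken from this commit):

    from bigdl.llm.langchain.embeddings import TransformersBgeEmbeddings

    # Load a BGE embedding model; the model id here is only an example.
    bge = TransformersBgeEmbeddings.from_model_id(model_id="BAAI/bge-small-en-v1.5")

    # embed_query routes into the overridden embed method added below.
    vector = bge.embed_query("What is BigDL-LLM?")
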
@@ -45,6 +45,7 @@
 # THE SOFTWARE.
 
 """Wrapper around BigdlLLM embedding models."""
+import torch
 from typing import Any, Dict, List, Optional
 import numpy as np
 
@@ -181,3 +182,14 @@ class TransformersEmbeddings(BaseModel, Embeddings):
         text = text.replace("\n", " ")
         embedding = self.embed(text, **self.encode_kwargs)
         return embedding.tolist()
+
+
+# fit specific encode method for langchain.embeddings.HuggingFaceBgeEmbeddings
+# TODO: directly support HuggingFaceBgeEmbeddings
+class TransformersBgeEmbeddings(TransformersEmbeddings):
+
+    def embed(self, text: str, **kwargs):
+        input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs)
+        input_ids = input_ids.to(self.model.device)
+        embeddings = self.model(input_ids, return_dict=False)[0].cpu()
+        embeddings = torch.nn.functional.normalize(embeddings[:, 0], p=2, dim=1)
+        return embeddings[0]
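The overridden embed follows the BGE recipe: run the encoder, keep the hidden state of the first ([CLS]) token, and L2-normalize it, mirroring the encode path of langchain's HuggingFaceBgeEmbeddings. A self-contained sketch of the same pooling using plain transformers (model id illustrative, not part of this commit):

    import torch
    from transformers import AutoModel, AutoTokenizer

    model_id = "BAAI/bge-small-en-v1.5"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModel.from_pretrained(model_id)

    def bge_embed(text: str) -> torch.Tensor:
        # Encode the text and run the model; with return_dict=False the
        # first tuple element is the last hidden state (1, seq_len, dim).
        input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
        hidden = model(input_ids, return_dict=False)[0]
        # BGE pooling: take the [CLS] token, then L2-normalize.
        return torch.nn.functional.normalize(hidden[:, 0], p=2, dim=1)[0]

Because the vectors come out unit-length, similarity between two embeddings reduces to a plain dot product.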