[LLM] Add TransformersBgeEmbeddings class in bigdl.llm.langchain.embeddings (#10459)
* Add TransformersBgeEmbeddings class in bigdl.llm.langchain.embeddings
* Small fixes
parent 463a86cd5d
commit 72bcc27da9

2 changed files with 15 additions and 2 deletions
bigdl/llm/langchain/embeddings/__init__.py

@@ -20,7 +20,7 @@
 # only search the first bigdl package and end up finding only one sub-package.
 
 from .bigdlllm import *
-from .transformersembeddings import TransformersEmbeddings
+from .transformersembeddings import TransformersEmbeddings, TransformersBgeEmbeddings
 
 __all__ = [
     "BigdlNativeEmbeddings",
@@ -28,5 +28,6 @@ __all__ = [
     "BloomEmbeddings",
     "GptneoxEmbeddings",
     "StarcoderEmbeddings",
-    "TransformersEmbeddings"
+    "TransformersEmbeddings",
+    "TransformersBgeEmbeddings"
 ]
bigdl/llm/langchain/embeddings/transformersembeddings.py

@@ -45,6 +45,7 @@
 # THE SOFTWARE.
 
 """Wrapper around BigdlLLM embedding models."""
+import torch
 from typing import Any, Dict, List, Optional
 import numpy as np
 
@@ -181,3 +182,14 @@ class TransformersEmbeddings(BaseModel, Embeddings):
         text = text.replace("\n", " ")
         embedding = self.embed(text, **self.encode_kwargs)
         return embedding.tolist()
+
+# fit specific encode method for langchain.embeddings.HuggingFaceBgeEmbeddings
+# TODO: directly support HuggingFaceBgeEmbeddings
+class TransformersBgeEmbeddings(TransformersEmbeddings):
+
+    def embed(self, text: str, **kwargs):
+        input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs)
+        input_ids = input_ids.to(self.model.device)
+        embeddings = self.model(input_ids, return_dict=False)[0].cpu()
+        embeddings = torch.nn.functional.normalize(embeddings[:, 0], p=2, dim=1)
+        return embeddings[0]
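The overridden embed() pools the hidden state of the first ([CLS]) token and L2-normalizes it, which is the pooling convention BGE embedding models expect and what langchain.embeddings.HuggingFaceBgeEmbeddings does; per the TODO above, direct support for HuggingFaceBgeEmbeddings is left for later. Below is a minimal usage sketch, not part of this commit: it assumes the subclass inherits a from_model_id() factory and the embed_query()/embed_documents() interface from TransformersEmbeddings, and the model id is only an example.

from bigdl.llm.langchain.embeddings import TransformersBgeEmbeddings

# Assumption: TransformersEmbeddings exposes a from_model_id() factory that the
# new subclass inherits; "BAAI/bge-small-en-v1.5" is an example BGE model id.
bge = TransformersBgeEmbeddings.from_model_id(model_id="BAAI/bge-small-en-v1.5")

# LangChain Embeddings interface inherited from TransformersEmbeddings: both
# calls route through the overridden embed(), so each vector is the
# L2-normalized [CLS] representation computed in the diff above.
query_vec = bge.embed_query("What is BigDL-LLM?")                     # List[float]
doc_vecs = bge.embed_documents(["BigDL-LLM runs LLMs with low-bit optimizations."])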