From 676d6923f2cf17244a5f8a639dfad81c71a45e27 Mon Sep 17 00:00:00 2001
From: Zhicun <59141989+ivy-lv11@users.noreply.github.com>
Date: Mon, 5 Feb 2024 10:42:10 +0800
Subject: [PATCH] LLM: modify transformersembeddings.embed() in langchain
 (#10051)

---
 .../llm/langchain/embeddings/transformersembeddings.py | 4 ++--
 python/llm/test/langchain/test_transformers_api.py     | 6 ++++++
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
index d3974893..bc57ef29 100644
--- a/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
+++ b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
@@ -135,7 +135,7 @@ class TransformersEmbeddings(BaseModel, Embeddings):
 
         extra = Extra.forbid
 
-    def embed(self, text: str):
+    def embed(self, text: str, **kwargs):
         """Compute doc embeddings using a HuggingFace transformer model.
 
         Args:
@@ -144,7 +144,7 @@ class TransformersEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        input_ids = self.tokenizer.encode(text, return_tensors="pt")  # shape: [1, T]
+        input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs)  # shape: [1, T]
         embeddings = self.model(input_ids, return_dict=False)[0]  # shape: [1, T, N]
         embeddings = embeddings.squeeze(0).detach().numpy()
         embeddings = np.mean(embeddings, axis=0)
diff --git a/python/llm/test/langchain/test_transformers_api.py b/python/llm/test/langchain/test_transformers_api.py
index 90113460..c8e2ac53 100644
--- a/python/llm/test/langchain/test_transformers_api.py
+++ b/python/llm/test/langchain/test_transformers_api.py
@@ -129,6 +129,12 @@ class Test_Langchain_Transformers_API(TestCase):
         res = "AI" in output
         self.assertTrue(res)
         """
+
+    def test_embed_kwargs(self):
+        embeddings = TransformersEmbeddings.from_model_id(model_id=self.llama_model_path)
+        encode_kwargs = {"truncation": True, "max_length": 512}
+        en_texts = ["hello", "goodbye"]
+        embeddings.embed(en_texts, **encode_kwargs)
 
 
 if __name__ == '__main__':
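
--
Editor's note, not part of the patch: a minimal usage sketch of the new
keyword forwarding. TransformersEmbeddings.from_model_id() and the
truncation/max_length tokenizer arguments appear in the patched sources
above; the import path and model path are illustrative assumptions.

    # Any keyword accepted by HuggingFace's tokenizer.encode() can now be
    # passed straight through embed(), e.g. to cap the input length.
    from bigdl.llm.langchain.embeddings import TransformersEmbeddings

    embeddings = TransformersEmbeddings.from_model_id(
        model_id="path/to/llama-model")  # placeholder model path
    # Forwarded to tokenizer.encode(); the result is a single pooled
    # vector (mean over token embeddings, per the patched embed()).
    vec = embeddings.embed("hello world", truncation=True, max_length=512)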