diff --git a/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
index d3974893..bc57ef29 100644
--- a/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
+++ b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
@@ -135,7 +135,7 @@ class TransformersEmbeddings(BaseModel, Embeddings):
 
         extra = Extra.forbid
 
-    def embed(self, text: str):
+    def embed(self, text: str, **kwargs):
         """Compute doc embeddings using a HuggingFace transformer model.
 
         Args:
@@ -144,7 +144,7 @@ class TransformersEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        input_ids = self.tokenizer.encode(text, return_tensors="pt")  # shape: [1, T]
+        input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs)  # shape: [1, T]
         embeddings = self.model(input_ids, return_dict=False)[0]  # shape: [1, T, N]
         embeddings = embeddings.squeeze(0).detach().numpy()
         embeddings = np.mean(embeddings, axis=0)
diff --git a/python/llm/test/langchain/test_transformers_api.py b/python/llm/test/langchain/test_transformers_api.py
index 90113460..c8e2ac53 100644
--- a/python/llm/test/langchain/test_transformers_api.py
+++ b/python/llm/test/langchain/test_transformers_api.py
@@ -129,6 +129,12 @@ class Test_Langchain_Transformers_API(TestCase):
         res = "AI" in output
         self.assertTrue(res)
         """
+
+    def test_embed_kwargs(self):
+        embeddings = TransformersEmbeddings.from_model_id(model_id=self.llama_model_path)
+        encode_kwargs = {"truncation": True, "max_length": 512}
+        for text in ["hello", "goodbye"]:  # embed() takes one str at a time
+            embeddings.embed(text, **encode_kwargs)
 if __name__ == '__main__':