LLM: modify transformersembeddings.embed() in langchain (#10051)
This commit is contained in:
parent
ad050107b3
commit
676d6923f2
2 changed files with 8 additions and 2 deletions
|
|
@ -135,7 +135,7 @@ class TransformersEmbeddings(BaseModel, Embeddings):
|
||||||
|
|
||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
|
|
||||||
def embed(self, text: str):
|
def embed(self, text: str, **kwargs):
|
||||||
"""Compute doc embeddings using a HuggingFace transformer model.
|
"""Compute doc embeddings using a HuggingFace transformer model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -144,7 +144,7 @@ class TransformersEmbeddings(BaseModel, Embeddings):
|
||||||
Returns:
|
Returns:
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
input_ids = self.tokenizer.encode(text, return_tensors="pt") # shape: [1, T]
|
input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs) # shape: [1, T]
|
||||||
embeddings = self.model(input_ids, return_dict=False)[0] # shape: [1, T, N]
|
embeddings = self.model(input_ids, return_dict=False)[0] # shape: [1, T, N]
|
||||||
embeddings = embeddings.squeeze(0).detach().numpy()
|
embeddings = embeddings.squeeze(0).detach().numpy()
|
||||||
embeddings = np.mean(embeddings, axis=0)
|
embeddings = np.mean(embeddings, axis=0)
|
||||||
|
|
|
||||||
|
|
@ -130,6 +130,12 @@ class Test_Langchain_Transformers_API(TestCase):
|
||||||
self.assertTrue(res)
|
self.assertTrue(res)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def test_embed_kwargs(self):
|
||||||
|
embeddings = TransformersEmbeddings.from_model_id(model_id=self.llama_model_path)
|
||||||
|
encode_kwargs = {"truncation": True, "max_length": 512}
|
||||||
|
en_texts = ["hello","goodbye"]
|
||||||
|
embeddings.embed(en_texts,**encode_kwargs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue