LLM: modify transformersembeddings.embed() in langchain (#10051)
parent ad050107b3
commit 676d6923f2
2 changed files with 8 additions and 2 deletions
@@ -135,7 +135,7 @@ class TransformersEmbeddings(BaseModel, Embeddings):
         extra = Extra.forbid
 
-    def embed(self, text: str):
+    def embed(self, text: str, **kwargs):
         """Compute doc embeddings using a HuggingFace transformer model.
 
         Args:
@@ -144,7 +144,7 @@ class TransformersEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        input_ids = self.tokenizer.encode(text, return_tensors="pt")  # shape: [1, T]
+        input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs)  # shape: [1, T]
         embeddings = self.model(input_ids, return_dict=False)[0]  # shape: [1, T, N]
         embeddings = embeddings.squeeze(0).detach().numpy()
         embeddings = np.mean(embeddings, axis=0)
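For context, a minimal usage sketch of the new signature (not part of the commit; `emb` and the input string are hypothetical). Any keyword arguments passed to embed() are forwarded verbatim to tokenizer.encode(), so tokenizer options such as truncation and max_length can now be set per call:

# Hypothetical usage, assuming `emb = TransformersEmbeddings.from_model_id(model_id=...)`.
# The extra kwargs are passed straight through to tokenizer.encode().
vector = emb.embed(
    "a long input document ...",
    truncation=True,   # tokenizer option: truncate the encoded sequence
    max_length=512,    # tokenizer option: cap the sequence length
)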
@@ -130,6 +130,12 @@ class Test_Langchain_Transformers_API(TestCase):
         self.assertTrue(res)
     """
 
+    def test_embed_kwargs(self):
+        embeddings = TransformersEmbeddings.from_model_id(model_id=self.llama_model_path)
+        encode_kwargs = {"truncation": True, "max_length": 512}
+        en_texts = ["hello","goodbye"]
+        embeddings.embed(en_texts,**encode_kwargs)
+
 
 if __name__ == '__main__':
     pytest.main([__file__])