LLM: modify transformersembeddings.embed() in langchain (#10051)
This commit is contained in:
		
							parent
							
								
									ad050107b3
								
							
						
					
					
						commit
						676d6923f2
					
				
					 2 changed files with 8 additions and 2 deletions
				
			
		| 
						 | 
				
			
			@ -135,7 +135,7 @@ class TransformersEmbeddings(BaseModel, Embeddings):
 | 
			
		|||
 | 
			
		||||
        extra = Extra.forbid
 | 
			
		||||
    
 | 
			
		||||
    def embed(self, text: str):
 | 
			
		||||
    def embed(self, text: str, **kwargs):
 | 
			
		||||
        """Compute doc embeddings using a HuggingFace transformer model.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
| 
						 | 
				
			
			@ -144,7 +144,7 @@ class TransformersEmbeddings(BaseModel, Embeddings):
 | 
			
		|||
        Returns:
 | 
			
		||||
            List of embeddings, one for each text.
 | 
			
		||||
        """
 | 
			
		||||
        input_ids = self.tokenizer.encode(text, return_tensors="pt")  # shape: [1, T]
 | 
			
		||||
        input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs)  # shape: [1, T]
 | 
			
		||||
        embeddings = self.model(input_ids, return_dict=False)[0]  # shape: [1, T, N]
 | 
			
		||||
        embeddings = embeddings.squeeze(0).detach().numpy()
 | 
			
		||||
        embeddings = np.mean(embeddings, axis=0)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -129,6 +129,12 @@ class Test_Langchain_Transformers_API(TestCase):
 | 
			
		|||
        res = "AI" in output
 | 
			
		||||
        self.assertTrue(res)
 | 
			
		||||
    """
 | 
			
		||||
    
 | 
			
		||||
    def test_embed_kwargs(self):
 | 
			
		||||
        embeddings = TransformersEmbeddings.from_model_id(model_id=self.llama_model_path)
 | 
			
		||||
        encode_kwargs =  {"truncation": True, "max_length": 512}
 | 
			
		||||
        en_texts = ["hello","goodbye"]
 | 
			
		||||
        embeddings.embed(en_texts,**encode_kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue