[WIP] LLM transformers api for langchain (#8642)
parent 3d5a7484a2
commit e292dfd970

1 changed file with 81 additions and 0 deletions
python/llm/test/langchain/test_transformers_api.py  +81  (new file)
@@ -0,0 +1,81 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from bigdl.llm.langchain.llms import TransformersLLM, TransformersPipelineLLM
from bigdl.llm.langchain.embeddings import TransformersEmbeddings


from langchain.document_loaders import WebBaseLoader
from langchain.indexes import VectorstoreIndexCreator


from langchain.chains.question_answering import load_qa_chain
from langchain.chains.chat_vector_db.prompts import (CONDENSE_QUESTION_PROMPT,
                                                     QA_PROMPT)
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

import pytest
from unittest import TestCase
import os


class Test_Langchain_Transformers_API(TestCase):
    def setUp(self):
        # model paths and thread count are supplied via environment variables
        self.auto_model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
        self.auto_causal_model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH')
        thread_num = os.environ.get('THREAD_NUM')
        if thread_num is not None:
            self.n_threads = int(thread_num)
        else:
            self.n_threads = 2

    def test_pipeline_llm(self):
        texts = 'def hello():\n  print("hello world")\n'
        bigdl_llm = TransformersPipelineLLM.from_model_id(model_id=self.auto_causal_model_path, task='text-generation', model_kwargs={'trust_remote_code': True})

        output = bigdl_llm(texts)
        res = "hello()" in output
        self.assertTrue(res)


    def test_qa_chain(self):
        texts = '''
AI is a machine’s ability to perform the cognitive functions
we associate with human minds, such as perceiving, reasoning,
learning, interacting with an environment, problem solving,
and even exercising creativity. You’ve probably interacted
with AI even if you didn’t realize it—voice assistants like Siri
and Alexa are founded on AI technology, as are some customer
service chatbots that pop up to help you navigate websites.
        '''
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        texts = text_splitter.split_text(texts)
        query = 'What is AI?'
        embeddings = TransformersEmbeddings.from_model_id(model_id=self.auto_model_path, model_kwargs={'trust_remote_code': True})

        docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]).as_retriever()

        # get relevant texts
        docs = docsearch.get_relevant_documents(query)
        bigdl_llm = TransformersLLM.from_model_id(model_id=self.auto_model_path, model_kwargs={'trust_remote_code': True})
        doc_chain = load_qa_chain(bigdl_llm, chain_type="stuff", prompt=QA_PROMPT)
        output = doc_chain.run(input_documents=docs, question=query)
        res = "AI" in output
        self.assertTrue(res)

if __name__ == '__main__':
    pytest.main([__file__])
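Usage note: a minimal sketch of how this test file could be run locally, assuming the two model checkpoints referenced in setUp() are available on disk. The environment variable names come from the test above; the concrete paths and the THREAD_NUM value below are placeholders, not values from this PR.

import os
import pytest

# Placeholder paths; point these at your local model checkpoints.
os.environ["ORIGINAL_CHATGLM2_6B_PATH"] = "/path/to/chatglm2-6b"
os.environ["ORIGINAL_REPLIT_CODE_PATH"] = "/path/to/replit-code"
os.environ["THREAD_NUM"] = "4"  # optional; the test defaults to 2 when unset

# Run the new test module verbosely.
pytest.main(["python/llm/test/langchain/test_transformers_api.py", "-v"])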