[WIP] LLM transformers api for langchain (#8642)
This commit is contained in:
parent
3d5a7484a2
commit
e292dfd970
1 changed files with 81 additions and 0 deletions
81
python/llm/test/langchain/test_transformers_api.py
Normal file
81
python/llm/test/langchain/test_transformers_api.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from bigdl.llm.langchain.llms import TransformersLLM, TransformersPipelineLLM
|
||||
from bigdl.llm.langchain.embeddings import TransformersEmbeddings
|
||||
|
||||
|
||||
from langchain.document_loaders import WebBaseLoader
|
||||
from langchain.indexes import VectorstoreIndexCreator
|
||||
|
||||
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.chains.chat_vector_db.prompts import (CONDENSE_QUESTION_PROMPT,
|
||||
QA_PROMPT)
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain.vectorstores import Chroma
|
||||
|
||||
import pytest
|
||||
from unittest import TestCase
|
||||
import os
|
||||
|
||||
|
||||
class Test_Langchain_Transformers_API(TestCase):
|
||||
def setUp(self):
|
||||
self.auto_model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
|
||||
self.auto_causal_model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH')
|
||||
thread_num = os.environ.get('THREAD_NUM')
|
||||
if thread_num is not None:
|
||||
self.n_threads = int(thread_num)
|
||||
else:
|
||||
self.n_threads = 2
|
||||
|
||||
def test_pipeline_llm(self):
|
||||
texts = 'def hello():\n print("hello world")\n'
|
||||
bigdl_llm = TransformersPipelineLLM.from_model_id(model_id=self.auto_causal_model_path, task='text-generation', model_kwargs={'trust_remote_code': True})
|
||||
|
||||
output = bigdl_llm(texts)
|
||||
res = "hello()" in output
|
||||
self.assertTrue(res)
|
||||
|
||||
|
||||
def test_qa_chain(self):
|
||||
texts = '''
|
||||
AI is a machine’s ability to perform the cognitive functions
|
||||
we associate with human minds, such as perceiving, reasoning,
|
||||
learning, interacting with an environment, problem solving,
|
||||
and even exercising creativity. You’ve probably interacted
|
||||
with AI even if you didn’t realize it—voice assistants like Siri
|
||||
and Alexa are founded on AI technology, as are some customer
|
||||
service chatbots that pop up to help you navigate websites.
|
||||
'''
|
||||
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
||||
texts = text_splitter.split_text(texts)
|
||||
query = 'What is AI?'
|
||||
embeddings = TransformersEmbeddings.from_model_id(model_id=self.auto_model_path, model_kwargs={'trust_remote_code': True})
|
||||
|
||||
docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]).as_retriever()
|
||||
|
||||
#get relavant texts
|
||||
docs = docsearch.get_relevant_documents(query)
|
||||
bigdl_llm = TransformersLLM.from_model_id(model_id=self.auto_model_path, model_kwargs={'trust_remote_code': True})
|
||||
doc_chain = load_qa_chain(bigdl_llm, chain_type="stuff", prompt=QA_PROMPT)
|
||||
output = doc_chain.run(input_documents=docs, question=query)
|
||||
res = "AI" in output
|
||||
self.assertTrue(res)
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
Loading…
Reference in a new issue