ipex-llm/python/llm/test/langchain/test_transformers_api.py
Jiao Wang 0a06a6e1d4
Update tests for transformers 4.36 (#10858)
* update unit test

* update

* update

* update

* update

* update

* fix gpu attention test

* update

* update

* update

* update

* update

* update

* update example test

* replace replit code

* update

* update

* update

* update

* set safe_serialization false

* perf test

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* delete

* update

* update

* update

* update

* update

* update

* revert

* update
2024-05-24 10:26:38 +08:00

141 lines
6.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from ipex_llm.langchain.llms import TransformersLLM, TransformersPipelineLLM, \
LlamaLLM, BloomLLM
from ipex_llm.langchain.embeddings import TransformersEmbeddings, LlamaEmbeddings, \
BloomEmbeddings
from langchain.document_loaders import WebBaseLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.chat_vector_db.prompts import (CONDENSE_QUESTION_PROMPT,
QA_PROMPT)
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
import pytest
from unittest import TestCase
import os
class Test_Langchain_Transformers_API(TestCase):
    """Integration tests for the ipex-llm LangChain LLM and embedding wrappers.

    Every test loads a model from a local path that is supplied through an
    environment variable (see ``setUp``). A test is skipped, with a message
    naming the variable, when the path it depends on is not configured.
    """

    # Passage that the QA-chain test indexes and then queries. It is built
    # from explicit literals so its content does not depend on how the
    # surrounding code is indented.
    _AI_PASSAGE = (
        "\n"
        "AI is a machines ability to perform the cognitive functions\n"
        "we associate with human minds, such as perceiving, reasoning,\n"
        "learning, interacting with an environment, problem solving,\n"
        "and even exercising creativity. Youve probably interacted\n"
        "with AI even if you didnt realize it—voice assistants like Siri\n"
        "and Alexa are founded on AI technology, as are some customer\n"
        "service chatbots that pop up to help you navigate websites.\n"
    )

    def setUp(self):
        """Read the model locations and the thread count from the environment.

        A model path attribute is ``None`` when its variable is unset;
        ``n_threads`` falls back to 2 when ``THREAD_NUM`` is unset.
        """
        env = os.environ
        self.auto_model_path = env.get('ORIGINAL_CHATGLM2_6B_PATH')
        self.auto_causal_model_path = env.get('ORIGINAL_CODESHELL_7B_PATH')
        self.llama_model_path = env.get('LLAMA_ORIGIN_PATH')
        self.bloom_model_path = env.get('BLOOM_ORIGIN_PATH')
        thread_num = env.get('THREAD_NUM')
        self.n_threads = int(thread_num) if thread_num is not None else 2

    def _require_path(self, model_path, env_var):
        """Return ``model_path``, or skip the running test if it is not set.

        Args:
            model_path: Value read from the environment in ``setUp``.
            env_var: Name of the variable the value came from; only used in
                the skip message.

        Returns:
            ``model_path`` unchanged when it is a non-empty string.

        Raises:
            unittest.SkipTest: When ``model_path`` is ``None`` or empty, so
                that ``None`` is never handed to a model loader.
        """
        if not model_path:
            self.skipTest(f"environment variable {env_var} is not set")
        return model_path

    def test_pipeline_llm(self):
        """Complete a code prompt through ``TransformersPipelineLLM``."""
        model_path = self._require_path(self.auto_causal_model_path,
                                        'ORIGINAL_CODESHELL_7B_PATH')
        texts = 'def hello():\n print("hello world")\n'
        bigdl_llm = TransformersPipelineLLM.from_model_id(
            model_id=model_path,
            task='text-generation',
            model_kwargs={'trust_remote_code': True},
        )
        output = bigdl_llm(texts)
        # assertIn reports the generated text when the check fails.
        self.assertIn("hello()", output)

    def test_causalLM_embeddings(self):
        """Smoke-test ``BloomEmbeddings`` and ``BloomLLM`` with ``native=False``.

        The test makes no assertion on the results; it passes when none of
        the calls raises.
        """
        model_path = self._require_path(self.bloom_model_path,
                                        'BLOOM_ORIGIN_PATH')
        text = "This is a test document."
        bigdl_embeddings = BloomEmbeddings(
            model_path=model_path,
            model_kwargs={'trust_remote_code': True},
            native=False,
        )
        bigdl_embeddings.embed_query(text)
        bigdl_embeddings.embed_documents([text])
        bigdl_llm = BloomLLM(
            model_path=model_path,
            model_kwargs={'trust_remote_code': True},
            native=False,
        )
        bigdl_llm(text)

    # NOTE: disabled test, kept for reference (it was disabled in the
    # original file by wrapping it in a string literal).
    #
    # def test_transformers_llama_embeddings(self):
    #     bigdl_embeddings = TransformersEmbeddings.from_model_id(
    #         model_id=self.llama_model_path,
    #         model_kwargs={'trust_remote_code': True})
    #     text = "This is a test document."
    #     query_result = bigdl_embeddings.embed_query(text)
    #     doc_result = bigdl_embeddings.embed_documents([text])
    #     bigdl_llm = TransformersLLM.from_model_id(
    #         model_id=self.llama_model_path,
    #         model_kwargs={'trust_remote_code': True})
    #     res = bigdl_llm(text)

    def test_qa_chain(self):
        """Answer a question over the AI passage with a "stuff" QA chain."""
        model_path = self._require_path(self.auto_model_path,
                                        'ORIGINAL_CHATGLM2_6B_PATH')
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        texts = text_splitter.split_text(self._AI_PASSAGE)
        query = 'What is AI?'
        embeddings = TransformersEmbeddings.from_model_id(
            model_id=model_path,
            model_kwargs={'trust_remote_code': True},
        )
        docsearch = Chroma.from_texts(
            texts,
            embeddings,
            metadatas=[{"source": str(i)} for i in range(len(texts))],
        ).as_retriever()
        # Retrieve the chunks relevant to the query.
        docs = docsearch.get_relevant_documents(query)
        bigdl_llm = TransformersLLM.from_model_id(
            model_id=model_path,
            model_kwargs={'trust_remote_code': True},
        )
        doc_chain = load_qa_chain(bigdl_llm, chain_type="stuff", prompt=QA_PROMPT)
        output = doc_chain.run(input_documents=docs, question=query)
        self.assertIn("AI", output)

    # NOTE: disabled test, kept for reference (it was disabled in the
    # original file by wrapping it in a string literal). It used the same
    # passage as test_qa_chain.
    #
    # def test_qa_chain_causalLM(self):
    #     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    #     texts = text_splitter.split_text(self._AI_PASSAGE)
    #     query = 'What is AI?'
    #     embeddings = LlamaEmbeddings(
    #         model_path=self.llama_model_path,
    #         model_kwargs={'trust_remote_code': True}, native=False)
    #     docsearch = Chroma.from_texts(
    #         texts, embeddings,
    #         metadatas=[{"source": str(i)} for i in range(len(texts))]
    #     ).as_retriever()
    #     # Retrieve the chunks relevant to the query.
    #     docs = docsearch.get_relevant_documents(query)
    #     bigdl_llm = LlamaLLM(
    #         model_path=self.llama_model_path,
    #         model_kwargs={'trust_remote_code': True}, native=False)
    #     doc_chain = load_qa_chain(bigdl_llm, chain_type="stuff", prompt=QA_PROMPT)
    #     output = doc_chain.run(input_documents=docs, question=query)
    #     res = "AI" in output
    #     self.assertTrue(res)

    def test_embed_kwargs(self):
        """Smoke-test ``TransformersEmbeddings.embed`` with tokenizer kwargs.

        The test makes no assertion on the result; it passes when the call
        accepts the extra keyword arguments without raising.
        """
        model_path = self._require_path(self.llama_model_path,
                                        'LLAMA_ORIGIN_PATH')
        embeddings = TransformersEmbeddings.from_model_id(model_id=model_path)
        encode_kwargs = {"truncation": True, "max_length": 512}
        en_texts = ["hello", "goodbye"]
        embeddings.embed(en_texts, **encode_kwargs)
# Allow running this module directly as a script. pytest.main returns the
# session's exit code; raising SystemExit with it makes the process exit
# status reflect test failures instead of always being 0.
if __name__ == '__main__':
    raise SystemExit(pytest.main([__file__]))