Add LlamaIndex RAG (#10263)
* run demo
* format code
* add llamaindex
* add custom LLM with bigdl
* update
* add readme
* begin ut
* add unit test
* add license
* add license
* revised
* update
* modify docs
* remove data folder
* update
* modify prompt
* fixed
* fixed
* fixed
This commit is contained in:
parent 5d7243067c
commit 4e6cc424f1

6 changed files with 898 additions and 0 deletions

python/llm/example/CPU/LlamaIndex/README.md (new file, 60 additions)
@@ -0,0 +1,60 @@
# LlamaIndex Examples

The examples here show how to use LlamaIndex with `bigdl-llm`.
The RAG example is modified from the [demo](https://docs.llamaindex.ai/en/stable/examples/low_level/oss_ingestion_retrieval.html).

## Install bigdl-llm
Follow the instructions in [Install](https://github.com/intel-analytics/BigDL/tree/main/python/llm#install).
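At the time of writing, the CPU install typically comes down to the command below; treat it as a convenience reminder and defer to the linked instructions if they differ.
```bash
pip install --pre --upgrade bigdl-llm[all]
```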

## Install Required Dependencies for the LlamaIndex Examples

### Install Site-packages
```bash
pip install llama-index-readers-file
pip install llama-index-vector-stores-postgres
pip install llama-index-embeddings-huggingface
```

### Install Postgres
> Note: There are plenty of open-source databases you can use. Here we provide an example using Postgres.
* Download and install Postgres by running the commands below.
    ```bash
    sudo apt-get install postgresql-client
    sudo apt-get install postgresql
    ```
* Initialize Postgres.
    ```bash
    sudo su - postgres
    psql
    ```
    After running these commands, you reach the Postgres console. There you can add a role as follows (an optional connection check is shown after this list):
    ```bash
    CREATE ROLE <user> WITH LOGIN PASSWORD '<password>';
    ALTER ROLE <user> SUPERUSER;
    ```
* Install pgvector according to its [project page](https://github.com/pgvector/pgvector). If you run into problems during installation, the [installation notes](https://github.com/pgvector/pgvector#installation-notes) may help.
* Download the document used for retrieval.
    ```bash
    mkdir data
    wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf"
    ```
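Optionally, you can confirm that the new role can connect before moving on. This is just a sanity check, not something the example requires; replace `<user>` with the role you created above:
```bash
psql -h localhost -p 5432 -U <user> -d postgres -c "SELECT version();"
```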

## Run the examples

### Retrieval-augmented Generation
```bash
python rag.py -m MODEL_PATH -e EMBEDDING_MODEL_PATH -u USERNAME -p PASSWORD -q QUESTION -d DATA
```
arguments info:
- `-m MODEL_PATH`: **required**, path to the llama model
- `-e EMBEDDING_MODEL_PATH`: path to the embedding model
- `-u USERNAME`: username in the postgres database
- `-p PASSWORD`: password in the postgres database
- `-q QUESTION`: question you want to ask
- `-d DATA`: path to data used during retrieval
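For example, assuming the model lives at `./Llama-2-7b-chat-hf` (a placeholder path) and the defaults from the setup above are kept:
```bash
python rag.py -m ./Llama-2-7b-chat-hf \
              -e BAAI/bge-small-en \
              -u <user> -p <password> \
              -q "How does Llama 2 perform compared to other open-source models?" \
              -d ./data/llama2.pdf
```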

Here is the sample output when using Llama-2-7b-chat-hf as the generation model, asking "How does Llama 2 perform compared to other open-source models?" with llama2.pdf as the retrieval data.
```
Llama 2 performs better than most open-source models on the benchmarks we tested. Specifically, it outperforms all open-source models on MMLU and BBH, and is close to GPT-3.5 on these benchmarks. Additionally, Llama 2 is on par or better than PaLM-2-L on almost all benchmarks. The only exception is the coding benchmarks, where Llama 2 lags significantly behind GPT-4 and PaLM-2-L. Overall, Llama 2 demonstrates strong performance on a wide range of natural language processing tasks.
```

python/llm/example/CPU/LlamaIndex/rag.py (new file, 247 additions)
@@ -0,0 +1,247 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from sqlalchemy import make_url
from llama_index.vector_stores.postgres import PGVectorStore
# from llama_index.llms.llama_cpp import LlamaCPP
import psycopg2
from pathlib import Path
from llama_index.readers.file import PyMuPDFReader
from llama_index.core.schema import NodeWithScore
from typing import Optional
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core import QueryBundle
from llama_index.core.retrievers import BaseRetriever
from typing import Any, List
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.vector_stores import VectorStoreQuery
import argparse

def load_vector_database(username, password):
    db_name = "example_db"
    host = "localhost"
    password = password
    port = "5432"
    user = username
    # conn = psycopg2.connect(connection_string)
    conn = psycopg2.connect(
        dbname="postgres",
        host=host,
        password=password,
        port=port,
        user=user,
    )
    conn.autocommit = True

    with conn.cursor() as c:
        c.execute(f"DROP DATABASE IF EXISTS {db_name}")
        c.execute(f"CREATE DATABASE {db_name}")

    vector_store = PGVectorStore.from_params(
        database=db_name,
        host=host,
        password=password,
        port=port,
        user=user,
        table_name="llama2_paper",
        embed_dim=384,  # embedding dimension of the bge-small-en embedding model
    )
    return vector_store


def load_data(data_path):
    loader = PyMuPDFReader()
    documents = loader.load(file_path=data_path)

    text_parser = SentenceSplitter(
        chunk_size=1024,
        # separator=" ",
    )
    text_chunks = []
    # maintain relationship with source doc index, to help inject doc metadata later
    doc_idxs = []
    for doc_idx, doc in enumerate(documents):
        cur_text_chunks = text_parser.split_text(doc.text)
        text_chunks.extend(cur_text_chunks)
        doc_idxs.extend([doc_idx] * len(cur_text_chunks))

    from llama_index.core.schema import TextNode
    nodes = []
    for idx, text_chunk in enumerate(text_chunks):
        node = TextNode(
            text=text_chunk,
        )
        src_doc = documents[doc_idxs[idx]]
        node.metadata = src_doc.metadata
        nodes.append(node)
    return nodes


class VectorDBRetriever(BaseRetriever):
    """Retriever over a postgres vector store."""

    def __init__(
        self,
        vector_store: PGVectorStore,
        embed_model: Any,
        query_mode: str = "default",
        similarity_top_k: int = 2,
    ) -> None:
        """Init params."""
        self._vector_store = vector_store
        self._embed_model = embed_model
        self._query_mode = query_mode
        self._similarity_top_k = similarity_top_k
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve the top-k nodes most similar to the query."""
        query_embedding = self._embed_model.get_query_embedding(
            query_bundle.query_str
        )
        vector_store_query = VectorStoreQuery(
            query_embedding=query_embedding,
            similarity_top_k=self._similarity_top_k,
            mode=self._query_mode,
        )
        query_result = self._vector_store.query(vector_store_query)

        nodes_with_scores = []
        for index, node in enumerate(query_result.nodes):
            score: Optional[float] = None
            if query_result.similarities is not None:
                score = query_result.similarities[index]
            nodes_with_scores.append(NodeWithScore(node=node, score=score))

        return nodes_with_scores


def completion_to_prompt(completion):
    return f"<|system|>\n</s>\n<|user|>\n{completion}</s>\n<|assistant|>\n"


# Transform a list of chat messages into zephyr-specific input
def messages_to_prompt(messages):
    prompt = ""
    for message in messages:
        if message.role == "system":
            prompt += f"<|system|>\n{message.content}</s>\n"
        elif message.role == "user":
            prompt += f"<|user|>\n{message.content}</s>\n"
        elif message.role == "assistant":
            prompt += f"<|assistant|>\n{message.content}</s>\n"

    # ensure we start with a system prompt, insert blank if needed
    if not prompt.startswith("<|system|>\n"):
        prompt = "<|system|>\n</s>\n" + prompt

    # add final assistant prompt
    prompt = prompt + "<|assistant|>\n"

    return prompt

def main(args):
    embed_model = HuggingFaceEmbedding(model_name=args.embedding_model_path)

    # Use the custom BigDL-LLM wrapper
    from bigdl.llm.llamaindex.llms import BigdlLLM
    llm = BigdlLLM(
        model_name=args.model_path,
        tokenizer_name=args.model_path,
        context_window=512,
        max_new_tokens=32,
        generate_kwargs={"temperature": 0.7, "do_sample": False},
        model_kwargs={},
        messages_to_prompt=messages_to_prompt,
        completion_to_prompt=completion_to_prompt,
        device_map="cpu",
    )

    # Embed the document chunks and store them in the vector database
    vector_store = load_vector_database(username=args.user, password=args.password)
    nodes = load_data(data_path=args.data)
    for node in nodes:
        node_embedding = embed_model.get_text_embedding(
            node.get_content(metadata_mode="all")
        )
        node.embedding = node_embedding

    vector_store.add(nodes)

    # Run a raw vector store query first as a retrieval sanity check
    # query_str = "Can you tell me about the key concepts for safety finetuning"
    query_str = "Explain about the training data for Llama 2"
    query_embedding = embed_model.get_query_embedding(query_str)

    # construct vector store query
    query_mode = "default"
    # query_mode = "sparse"
    # query_mode = "hybrid"

    vector_store_query = VectorStoreQuery(
        query_embedding=query_embedding, similarity_top_k=2, mode=query_mode
    )
    # returns a VectorStoreQueryResult
    query_result = vector_store.query(vector_store_query)
    # print("Retrieval Results: ")
    # print(query_result.nodes[0].get_content())

    nodes_with_scores = []
    for index, node in enumerate(query_result.nodes):
        score: Optional[float] = None
        if query_result.similarities is not None:
            score = query_result.similarities[index]
        nodes_with_scores.append(NodeWithScore(node=node, score=score))

    # Build the retriever and query engine, then answer the user's question
    retriever = VectorDBRetriever(
        vector_store, embed_model, query_mode="default", similarity_top_k=1
    )

    query_engine = RetrieverQueryEngine.from_args(retriever, llm=llm)

    # query_str = "How does Llama 2 perform compared to other open-source models?"
    query_str = args.question
    response = query_engine.query(query_str)

    print("------------RESPONSE GENERATION---------------------")
    print(str(response))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='LlamaIndex BigdlLLM Example')
    parser.add_argument('-m', '--model-path', type=str, required=True,
                        help='path to the Transformers model')
    parser.add_argument('-q', '--question', type=str, default='How does Llama 2 perform compared to other open-source models?',
                        help='question you want to ask')
    parser.add_argument('-d', '--data', type=str, default='./data/llama2.pdf',
                        help='path to the data used during retrieval')
    parser.add_argument('-u', '--user', type=str, required=True,
                        help='username in the Postgres database')
    parser.add_argument('-p', '--password', type=str, required=True,
                        help='password of the user in the Postgres database')
    parser.add_argument('-e', '--embedding-model-path', default="BAAI/bge-small-en",
                        help='path to the embedding model')
    args = parser.parse_args()

    main(args)

python/llm/src/bigdl/llm/llamaindex/__init__.py (new file, 20 additions)
@@ -0,0 +1,20 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in non-pip settings, as Python would
# only search the first bigdl package and end up finding only one sub-package.

python/llm/src/bigdl/llm/llamaindex/llms/__init__.py (new file, 34 additions)
@@ -0,0 +1,34 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in non-pip settings, as Python would
# only search the first bigdl package and end up finding only one sub-package.

"""Wrappers on top of large language model APIs."""
from typing import Dict, Type

from .bigdlllm import *
from llama_index.core.base.llms.base import BaseLLM

__all__ = [
    "BigdlLLM",
]

type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
    "BigdlLLM": BigdlLLM,
}

python/llm/src/bigdl/llm/llamaindex/llms/bigdlllm.py (new file, 449 additions)
@@ -0,0 +1,449 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# The file is modified from
# https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/base/llms/base.py

# The MIT License

# Copyright (c) Harrison Chase

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

import logging
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import torch
from huggingface_hub import AsyncInferenceClient, InferenceClient, model_info
from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager
from llama_index.core.constants import (
    DEFAULT_CONTEXT_WINDOW,
    DEFAULT_NUM_OUTPUTS,
)
from llama_index.core.llms.callbacks import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.core.llms.custom import CustomLLM

from llama_index.core.base.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.core.prompts.base import PromptTemplate
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
from transformers import (
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers import AutoTokenizer, LlamaTokenizer


DEFAULT_HUGGINGFACE_MODEL = "meta-llama/Llama-2-7b-chat-hf"

logger = logging.getLogger(__name__)

class BigdlLLM(CustomLLM):
    """Wrapper around a model loaded and optimized with BigDL-LLM.

    Example:
        .. code-block:: python

            from bigdl.llm.llamaindex.llms import BigdlLLM
            llm = BigdlLLM(
                model_name="/path/to/llama/model",
                tokenizer_name="/path/to/llama/model",
            )
    """

    model_name: str = Field(
        default=DEFAULT_HUGGINGFACE_MODEL,
        description=(
            "The model name to use from HuggingFace. "
            "Unused if `model` is passed in directly."
        ),
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum number of tokens available for input.",
        gt=0,
    )
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The maximum number of tokens to generate.",
        gt=0,
    )
    system_prompt: str = Field(
        default="",
        description=(
            "The system prompt, containing any extra instructions or context. "
            "The model card on HuggingFace should specify if this is needed."
        ),
    )
    query_wrapper_prompt: PromptTemplate = Field(
        default=PromptTemplate("{query_str}"),
        description=(
            "The query wrapper prompt, containing the query placeholder. "
            "The model card on HuggingFace should specify if this is needed. "
            "Should contain a `{query_str}` placeholder."
        ),
    )
    tokenizer_name: str = Field(
        default=DEFAULT_HUGGINGFACE_MODEL,
        description=(
            "The name of the tokenizer to use from HuggingFace. "
            "Unused if `tokenizer` is passed in directly."
        ),
    )
    device_map: str = Field(
        default="auto", description="The device_map to use. Defaults to 'auto'."
    )
    stopping_ids: List[int] = Field(
        default_factory=list,
        description=(
            "The stopping ids to use. "
            "Generation stops when these token IDs are predicted."
        ),
    )
    tokenizer_outputs_to_remove: list = Field(
        default_factory=list,
        description=(
            "The outputs to remove from the tokenizer. "
            "Sometimes huggingface tokenizers return extra inputs that cause errors."
        ),
    )
    tokenizer_kwargs: dict = Field(
        default_factory=dict, description="The kwargs to pass to the tokenizer."
    )
    model_kwargs: dict = Field(
        default_factory=dict,
        description="The kwargs to pass to the model during initialization.",
    )
    generate_kwargs: dict = Field(
        default_factory=dict,
        description="The kwargs to pass to the model during generation.",
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            LLMMetadata.__fields__["is_chat_model"].field_info.description
            + " Be sure to verify that you either pass an appropriate tokenizer "
            "that can convert prompts to properly formatted chat messages or a "
            "`messages_to_prompt` that does so."
        ),
    )

    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _stopping_criteria: Any = PrivateAttr()

    def __init__(
        self,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        query_wrapper_prompt: Union[str, PromptTemplate] = "{query_str}",
        tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL,
        model_name: str = DEFAULT_HUGGINGFACE_MODEL,
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        device_map: Optional[str] = "auto",
        stopping_ids: Optional[List[int]] = None,
        tokenizer_kwargs: Optional[dict] = None,
        tokenizer_outputs_to_remove: Optional[list] = None,
        model_kwargs: Optional[dict] = None,
        generate_kwargs: Optional[dict] = None,
        is_chat_model: Optional[bool] = False,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: str = "",
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """
        Construct BigdlLLM.

        Args:
            context_window: The maximum number of tokens available for input.
            max_new_tokens: The maximum number of tokens to generate.
            query_wrapper_prompt: The query wrapper prompt, containing the query placeholder.
                        Should contain a `{query_str}` placeholder.
            tokenizer_name: The name of the tokenizer to use from HuggingFace.
                        Unused if `tokenizer` is passed in directly.
            model_name: The model name to use from HuggingFace.
                        Unused if `model` is passed in directly.
            model: The HuggingFace model.
            tokenizer: The tokenizer.
            device_map: The device_map to use. Defaults to 'auto'.
            stopping_ids: The stopping ids to use.
                        Generation stops when these token IDs are predicted.
            tokenizer_kwargs: The kwargs to pass to the tokenizer.
            tokenizer_outputs_to_remove: The outputs to remove from the tokenizer.
                        Sometimes huggingface tokenizers return extra inputs that cause errors.
            model_kwargs: The kwargs to pass to the model during initialization.
            generate_kwargs: The kwargs to pass to the model during generation.
            is_chat_model: Whether the model is a chat model.
            callback_manager: Callback manager.
            system_prompt: The system prompt, containing any extra instructions or context.
            messages_to_prompt: Function to convert a sequence of chat messages to a prompt.
            completion_to_prompt: Function to convert a plain completion string to a prompt.
            pydantic_program_mode: Pydantic program mode. Defaults to DEFAULT.
            output_parser: Optional BaseOutputParser applied to the output.

        Returns:
            None.
        """
        model_kwargs = model_kwargs or {}
        from bigdl.llm.transformers import AutoModelForCausalLM
        self._model = model or AutoModelForCausalLM.from_pretrained(
            model_name, load_in_4bit=True, **model_kwargs
        )

        # check context_window
        config_dict = self._model.config.to_dict()
        model_context_window = int(
            config_dict.get("max_position_embeddings", context_window)
        )
        if model_context_window and model_context_window < context_window:
            logger.warning(
                f"Supplied context_window {context_window} is greater "
                f"than the model's max input size {model_context_window}. "
                "Disable this warning by setting a lower context_window."
            )
            context_window = model_context_window

        tokenizer_kwargs = tokenizer_kwargs or {}
        if "max_length" not in tokenizer_kwargs:
            tokenizer_kwargs["max_length"] = context_window

        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(
            tokenizer_name, **tokenizer_kwargs
        )

        if tokenizer_name != model_name:
            logger.warning(
                f"The model `{model_name}` and tokenizer `{tokenizer_name}` "
                f"are different, please ensure that they are compatible."
            )

        # setup stopping criteria
        stopping_ids_list = stopping_ids or []

        class StopOnTokens(StoppingCriteria):
            def __call__(
                self,
                input_ids: torch.LongTensor,
                scores: torch.FloatTensor,
                **kwargs: Any,
            ) -> bool:
                for stop_id in stopping_ids_list:
                    if input_ids[0][-1] == stop_id:
                        return True
                return False

        self._stopping_criteria = StoppingCriteriaList([StopOnTokens()])

        if isinstance(query_wrapper_prompt, str):
            query_wrapper_prompt = PromptTemplate(query_wrapper_prompt)

        messages_to_prompt = messages_to_prompt or self._tokenizer_messages_to_prompt

        super().__init__(
            context_window=context_window,
            max_new_tokens=max_new_tokens,
            query_wrapper_prompt=query_wrapper_prompt,
            tokenizer_name=tokenizer_name,
            model_name=model_name,
            device_map=device_map,
            stopping_ids=stopping_ids or [],
            tokenizer_kwargs=tokenizer_kwargs or {},
            tokenizer_outputs_to_remove=tokenizer_outputs_to_remove or [],
            model_kwargs=model_kwargs or {},
            generate_kwargs=generate_kwargs or {},
            is_chat_model=is_chat_model,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """
        Get class name.

        Returns:
            Str of class name.
        """
        return "BigDL_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """
        Get metadata.

        Returns:
            LLMMetadata containing context_window,
            num_output, model_name and is_chat_model.
        """
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_name,
            is_chat_model=self.is_chat_model,
        )

    def _tokenizer_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """
        Use the tokenizer to convert messages to prompt. Fallback to generic.

        Args:
            messages: Sequence of ChatMessage.

        Returns:
            Str of the formatted prompt.
        """
        if hasattr(self._tokenizer, "apply_chat_template"):
            messages_dict = [
                {"role": message.role.value, "content": message.content}
                for message in messages
            ]
            tokens = self._tokenizer.apply_chat_template(messages_dict)
            return self._tokenizer.decode(tokens)

        return generic_messages_to_prompt(messages)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """
        Complete by LLM.

        Args:
            prompt: Prompt for completion.
            formatted: Whether the prompt is already formatted by the wrapper.
            kwargs: Other kwargs for complete.

        Returns:
            CompletionResponse after generation.
        """
        full_prompt = prompt
        if not formatted:
            if self.query_wrapper_prompt:
                full_prompt = self.query_wrapper_prompt.format(query_str=prompt)
            if self.system_prompt:
                full_prompt = f"{self.system_prompt} {full_prompt}"
        input_ids = self._tokenizer(full_prompt, return_tensors="pt")
        input_ids = input_ids.to(self._model.device)
        # remove keys from the tokenizer if needed, to avoid HF errors
        for key in self.tokenizer_outputs_to_remove:
            if key in input_ids:
                input_ids.pop(key, None)
        tokens = self._model.generate(
            **input_ids,
            max_new_tokens=self.max_new_tokens,
            stopping_criteria=self._stopping_criteria,
            **self.generate_kwargs,
        )
        completion_tokens = tokens[0][input_ids["input_ids"].size(1):]
        completion = self._tokenizer.decode(completion_tokens, skip_special_tokens=True)

        return CompletionResponse(text=completion, raw={"model_output": tokens})

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """
        Complete by LLM in streaming mode.

        Args:
            prompt: Prompt for completion.
            formatted: Whether the prompt is already formatted by the wrapper.
            kwargs: Other kwargs for complete.

        Returns:
            A CompletionResponseGen yielding partial CompletionResponses.
        """
        from transformers import TextStreamer
        full_prompt = prompt
        if not formatted:
            if self.query_wrapper_prompt:
                full_prompt = self.query_wrapper_prompt.format(query_str=prompt)
            if self.system_prompt:
                full_prompt = f"{self.system_prompt} {full_prompt}"

        # tokenize to a dict of tensors so that extra keys can be dropped below
        input_ids = self._tokenizer(full_prompt, return_tensors="pt")
        input_ids = input_ids.to(self._model.device)

        for key in self.tokenizer_outputs_to_remove:
            if key in input_ids:
                input_ids.pop(key, None)

        streamer = TextStreamer(self._tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            input_ids,
            streamer=streamer,
            max_new_tokens=self.max_new_tokens,
            stopping_criteria=self._stopping_criteria,
            **self.generate_kwargs,
        )
        thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
        thread.start()

        # create generator based off of streamer
        def gen() -> CompletionResponseGen:
            text = ""
            for x in streamer:
                text += x
                yield CompletionResponse(text=text, delta=x)

        return gen()
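
For reference, a minimal sketch (not part of this commit) of how the `BigdlLLM` class above can be used on its own, outside the RAG pipeline; the model path is a placeholder and the generation settings simply mirror those used in `rag.py`:

```python
from bigdl.llm.llamaindex.llms import BigdlLLM

# Placeholder path; point this at a local Llama-2-7b-chat-hf checkout.
model_path = "/path/to/Llama-2-7b-chat-hf"

llm = BigdlLLM(
    model_name=model_path,
    tokenizer_name=model_path,
    context_window=512,
    max_new_tokens=32,
    generate_kwargs={"temperature": 0.7, "do_sample": False},
    device_map="cpu",
)

# One-shot completion, as exercised by the unit test below.
print(llm.complete("What is AI?").text)

# Streaming completion: each CompletionResponse carries the newly
# generated text in its `delta` field.
for chunk in llm.stream_complete("What is AI?"):
    print(chunk.delta, end="", flush=True)
```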
							
								
								
									
python/llm/test/llamaindex/test_llamaindex.py (new file, 88 additions)
@@ -0,0 +1,88 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from bigdl.llm.langchain.llms import TransformersLLM, TransformersPipelineLLM, \
    LlamaLLM, BloomLLM
from bigdl.llm.langchain.embeddings import TransformersEmbeddings, LlamaEmbeddings, \
    BloomEmbeddings

from langchain.document_loaders import WebBaseLoader
from langchain.indexes import VectorstoreIndexCreator

from langchain.chains.question_answering import load_qa_chain
from langchain.chains.chat_vector_db.prompts import (CONDENSE_QUESTION_PROMPT,
                                                     QA_PROMPT)
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

import pytest
from unittest import TestCase
import os
from bigdl.llm.llamaindex.llms import BigdlLLM


class Test_LlamaIndex_Transformers_API(TestCase):
    def setUp(self):
        self.auto_model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
        self.auto_causal_model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH')
        self.llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH')
        self.bloom_model_path = os.environ.get('BLOOM_ORIGIN_PATH')
        thread_num = os.environ.get('THREAD_NUM')
        if thread_num is not None:
            self.n_threads = int(thread_num)
        else:
            self.n_threads = 2

    def completion_to_prompt(self, completion):
        return f"<|system|>\n</s>\n<|user|>\n{completion}</s>\n<|assistant|>\n"

    def messages_to_prompt(self, messages):
        prompt = ""
        for message in messages:
            if message.role == "system":
                prompt += f"<|system|>\n{message.content}</s>\n"
            elif message.role == "user":
                prompt += f"<|user|>\n{message.content}</s>\n"
            elif message.role == "assistant":
                prompt += f"<|assistant|>\n{message.content}</s>\n"

        # ensure we start with a system prompt, insert blank if needed
        if not prompt.startswith("<|system|>\n"):
            prompt = "<|system|>\n</s>\n" + prompt

        # add final assistant prompt
        prompt = prompt + "<|assistant|>\n"
        return prompt

    def test_bigdl_llm(self):
        llm = BigdlLLM(
            model_name=self.llama_model_path,
            tokenizer_name=self.llama_model_path,
            context_window=512,
            max_new_tokens=32,
            model_kwargs={},
            generate_kwargs={"temperature": 0.7, "do_sample": False},
            messages_to_prompt=self.messages_to_prompt,
            completion_to_prompt=self.completion_to_prompt,
            device_map="cpu",
        )
        res = llm.complete("What is AI?")
        assert res is not None


if __name__ == '__main__':
    pytest.main([__file__])