Add llamaindex gpu example (#10314)
* add llamaindex example
* fix core dump
* refine readme
* add trouble shooting
* refine readme

---------

Co-authored-by: Ariadne <wyn2000330@126.com>
parent fc7f10cd12
commit 1e6f0c6f1a

3 changed files with 402 additions and 0 deletions

python/llm/example/GPU/LlamaIndex/README.md (new file, 151 lines)

@@ -0,0 +1,151 @@

# LlamaIndex Examples

This folder contains examples showcasing how to use [**LlamaIndex**](https://github.com/run-llama/llama_index) with `bigdl-llm`.

> [**LlamaIndex**](https://github.com/run-llama/llama_index) is a data framework designed to improve large language models by providing tools for easier data ingestion, management, and application integration.

## Retrieval-Augmented Generation (RAG) Example

The RAG example ([rag.py](./rag.py)) is adapted from the [official LlamaIndex RAG example](https://docs.llamaindex.ai/en/stable/examples/low_level/oss_ingestion_retrieval.html). It builds a pipeline to ingest data (e.g. the Llama 2 paper in PDF format) into a vector database (e.g. PostgreSQL), and then builds a retrieval pipeline from that vector database.

### 1. Setting up Dependencies

* **Install LlamaIndex Packages**

    ```bash
    pip install llama-index-readers-file llama-index-vector-stores-postgres llama-index-embeddings-huggingface
    ```

* **Install BigDL-LLM**

    Follow the instructions in the [GPU Install Guide](https://bigdl.readthedocs.io/en/latest/doc/LLM/Overview/install_gpu.html) to install `bigdl-llm`.
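
    If you install from pip, the guide's quick-install route is typically a single command along the lines of the sketch below; the exact extra index URL and version constraints depend on your oneAPI setup, so treat this as an assumption and prefer the guide:

    ```bash
    # Sketch only — the GPU Install Guide lists the exact command and extra index URL
    pip install --pre --upgrade bigdl-llm[xpu] --extra-index-url <url-from-the-guide>
    ```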

* **Database Setup (using PostgreSQL)**:
    * Installation:
        ```bash
        sudo apt-get install postgresql-client
        sudo apt-get install postgresql
        ```
    * Initialization:

      Switch to the **postgres** user and launch the **psql** console:
        ```bash
        sudo su - postgres
        psql
        ```
      Then, create a new user role:
        ```sql
        CREATE ROLE <user> WITH LOGIN PASSWORD '<password>';
        ALTER ROLE <user> SUPERUSER;
        ```
* **Pgvector Installation**:

    Follow the installation instructions on [pgvector's GitHub](https://github.com/pgvector/pgvector) and refer to the [installation notes](https://github.com/pgvector/pgvector#installation-notes) for additional help.
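
    For a typical build from source, the steps look roughly like the sketch below (the version tag is illustrative; check pgvector's README for the current release and for `PG_CONFIG` details):

    ```bash
    git clone --branch v0.6.0 https://github.com/pgvector/pgvector.git
    cd pgvector
    make
    sudo make install   # may need PG_CONFIG set if pg_config is not on PATH
    ```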

* **Data Preparation**: Download the Llama 2 paper and save it as `data/llama2.pdf`, which serves as the default source file for retrieval.

    ```bash
    mkdir data
    wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf"
    ```

### 2. Configure OneAPI Environment Variables

#### 2.1 Configurations for Linux

```bash
source /opt/intel/oneapi/setvars.sh
```

#### 2.2 Configurations for Windows

```cmd
call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat"
```

> Note: Please make sure you are using **CMD** (**Anaconda Prompt** if using conda) to run the command, as PowerShell is not supported.

### 3. Runtime Configurations

For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.

#### 3.1 Configurations for Linux

<details>

<summary>For Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series</summary>

```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```

</details>

<details>

<summary>For Intel Data Center GPU Max Series</summary>

```bash
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
export ENABLE_SDP_FUSION=1
```

> Note: `libtcmalloc.so` can be installed with `conda install -c conda-forge -y gperftools=2.10`.

</details>

#### 3.2 Configurations for Windows

<details>

<summary>For Intel iGPU</summary>

```cmd
set SYCL_CACHE_PERSISTENT=1
set BIGDL_LLM_XMX_DISABLED=1
```

</details>

<details>

<summary>For Intel Arc™ A300-Series or Pro A60</summary>

```cmd
set SYCL_CACHE_PERSISTENT=1
```

</details>

<details>

<summary>For other Intel dGPU Series</summary>

There is no need to set further environment variables.

</details>

> Note: The first time a model runs on an Intel iGPU, Intel Arc™ A300-Series, or Pro A60, compilation may take several minutes.

### 4. Running the RAG Example

In the current directory, run the example with:

```bash
python rag.py -m <path_to_model> -u <username> -p <password>
```

**Additional Parameters for Configuration** (an illustrative full command follows the list):
- `-m MODEL_PATH`: **required**, path to the LLM model
- `-u USERNAME`: **required**, username for the PostgreSQL database
- `-p PASSWORD`: **required**, password for the PostgreSQL database
- `-e EMBEDDING_MODEL_PATH`: path to the embedding model
- `-q QUESTION`: question you want to ask
- `-d DATA`: path to the source data used for retrieval (in PDF format)
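
For instance, with placeholder paths and credentials (all values below are illustrative):

```bash
python rag.py -m ./models/Llama-2-7b-chat-hf \
    -e ./models/bge-small-en \
    -u <user> -p <password> \
    -q "How does Llama 2 compare to other open-source models?" \
    -d ./data/llama2.pdf
```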

### 5. Example Output

A query such as **"How does Llama 2 compare to other open-source models?"** with the Llama 2 paper as the data source, using the `Llama-2-7b-chat-hf` model, will produce output like the following:

```
The comparison between Llama 2 and other open-source models is complex and depends on various factors such as the specific benchmarks used, the model size, and the task at hand.

In terms of performance on the benchmarks provided in the table, Llama 2 outperforms other open-source models on most categories. For example, on the MMLU benchmark, Llama 2 achieves a score of 22.5, while the next best open-source model, Poplar Aggregated Benchmarks, scores 17.5. Similarly, on the BBH benchmark, Llama 2 scores 20.5, while the next best open-source model scores 16.5.

However, it's important to note that the performance of Llama 2 can vary depending on the specific task and dataset being used. For example, on the coding benchmarks, Llama 2 performs significantly worse than other open-source models, such as PaLM (540B) and GPT-4.

In conclusion, while Llama 2 performs well on most benchmarks compared to other open-source models, its performance
```

### 6. Troubleshooting

#### 6.1 Core dump

If you encounter a core dump error in your Python code, verify that the `import torch` statement is placed at the top of your Python file, just as in `rag.py`.
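
In other words, the file should begin like this:

```python
import torch  # must come first, before any llama_index / bigdl imports

from llama_index.embeddings.huggingface import HuggingFaceEmbedding
```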

python/llm/example/GPU/LlamaIndex/rag.py (new file, 248 lines)

@@ -0,0 +1,248 @@

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Keep `import torch` at the top to avoid a core dump (see README, section 6.1).
import torch

import argparse
from typing import Any, List, Optional

import psycopg2
from llama_index.core import QueryBundle
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.core.vector_stores import VectorStoreQuery
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.readers.file import PyMuPDFReader
from llama_index.vector_stores.postgres import PGVectorStore


def load_vector_database(username, password):
    db_name = "example_db"
    host = "localhost"
    port = "5432"

    # (Re)create the example database.
    conn = psycopg2.connect(
        dbname="postgres",
        host=host,
        password=password,
        port=port,
        user=username,
    )
    conn.autocommit = True

    with conn.cursor() as c:
        c.execute(f"DROP DATABASE IF EXISTS {db_name}")
        c.execute(f"CREATE DATABASE {db_name}")

    vector_store = PGVectorStore.from_params(
        database=db_name,
        host=host,
        password=password,
        port=port,
        user=username,
        table_name="llama2_paper",
        embed_dim=384,  # embedding dimension of BAAI/bge-small-en
    )
    return vector_store


def load_data(data_path):
    loader = PyMuPDFReader()
    documents = loader.load(file_path=data_path)

    text_parser = SentenceSplitter(
        chunk_size=1024,
    )
    text_chunks = []
    # maintain relationship with source doc index, to help inject doc metadata
    doc_idxs = []
    for doc_idx, doc in enumerate(documents):
        cur_text_chunks = text_parser.split_text(doc.text)
        text_chunks.extend(cur_text_chunks)
        doc_idxs.extend([doc_idx] * len(cur_text_chunks))

    nodes = []
    for idx, text_chunk in enumerate(text_chunks):
        node = TextNode(text=text_chunk)
        src_doc = documents[doc_idxs[idx]]
        node.metadata = src_doc.metadata
        nodes.append(node)
    return nodes


class VectorDBRetriever(BaseRetriever):
    """Retriever over a postgres vector store."""

    def __init__(
        self,
        vector_store: PGVectorStore,
        embed_model: Any,
        query_mode: str = "default",
        similarity_top_k: int = 2,
    ) -> None:
        """Init params."""
        self._vector_store = vector_store
        self._embed_model = embed_model
        self._query_mode = query_mode
        self._similarity_top_k = similarity_top_k
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve."""
        query_embedding = self._embed_model.get_query_embedding(
            query_bundle.query_str
        )
        vector_store_query = VectorStoreQuery(
            query_embedding=query_embedding,
            similarity_top_k=self._similarity_top_k,
            mode=self._query_mode,
        )
        query_result = self._vector_store.query(vector_store_query)

        nodes_with_scores = []
        for index, node in enumerate(query_result.nodes):
            score: Optional[float] = None
            if query_result.similarities is not None:
                score = query_result.similarities[index]
            nodes_with_scores.append(NodeWithScore(node=node, score=score))

        return nodes_with_scores


def completion_to_prompt(completion):
    return f"<|system|>\n</s>\n<|user|>\n{completion}</s>\n<|assistant|>\n"


# Transform a list of chat messages into zephyr-specific input
def messages_to_prompt(messages):
    prompt = ""
    for message in messages:
        if message.role == "system":
            prompt += f"<|system|>\n{message.content}</s>\n"
        elif message.role == "user":
            prompt += f"<|user|>\n{message.content}</s>\n"
        elif message.role == "assistant":
            prompt += f"<|assistant|>\n{message.content}</s>\n"

    # ensure we start with a system prompt, insert blank if needed
    if not prompt.startswith("<|system|>\n"):
        prompt = "<|system|>\n</s>\n" + prompt

    # add final assistant prompt
    prompt = prompt + "<|assistant|>\n"

    return prompt


def main(args):
    embed_model = HuggingFaceEmbedding(model_name=args.embedding_model_path)

    # Use the custom LLM integration in BigDL
    from bigdl.llm.llamaindex.llms import BigdlLLM
    llm = BigdlLLM(
        model_name=args.model_path,
        tokenizer_name=args.model_path,
        context_window=512,
        max_new_tokens=32,
        generate_kwargs={"temperature": 0.7, "do_sample": False},
        model_kwargs={},
        messages_to_prompt=messages_to_prompt,
        completion_to_prompt=completion_to_prompt,
        device_map="xpu",
    )

    # Ingest: embed each node and store it in the vector database.
    vector_store = load_vector_database(username=args.user, password=args.password)
    nodes = load_data(data_path=args.data)
    for node in nodes:
        node_embedding = embed_model.get_text_embedding(
            node.get_content(metadata_mode="all")
        )
        node.embedding = node_embedding

    vector_store.add(nodes)

    # Demonstrate a direct vector store query; the retriever below performs
    # the same steps internally.
    query_str = "Explain about the training data for Llama 2"
    query_embedding = embed_model.get_query_embedding(query_str)

    query_mode = "default"
    # query_mode = "sparse"
    # query_mode = "hybrid"

    vector_store_query = VectorStoreQuery(
        query_embedding=query_embedding, similarity_top_k=2, mode=query_mode
    )
    # returns a VectorStoreQueryResult
    query_result = vector_store.query(vector_store_query)

    nodes_with_scores = []
    for index, node in enumerate(query_result.nodes):
        score: Optional[float] = None
        if query_result.similarities is not None:
            score = query_result.similarities[index]
        nodes_with_scores.append(NodeWithScore(node=node, score=score))

    # Retrieval + generation through the query engine.
    retriever = VectorDBRetriever(
        vector_store, embed_model, query_mode="default", similarity_top_k=1
    )
    query_engine = RetrieverQueryEngine.from_args(retriever, llm=llm)

    query_str = args.question
    response = query_engine.query(query_str)

    print("------------RESPONSE GENERATION---------------------")
    print(str(response))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='LlamaIndex BigdlLLM Example')
    parser.add_argument('-m', '--model-path', type=str, required=True,
                        help='path to the transformers model')
    parser.add_argument('-q', '--question', type=str,
                        default='How does Llama 2 perform compared to other open-source models?',
                        help='question you want to ask')
    parser.add_argument('-d', '--data', type=str, default='./data/llama2.pdf',
                        help='the data used during retrieval')
    parser.add_argument('-u', '--user', type=str, required=True,
                        help='username for the PostgreSQL database')
    parser.add_argument('-p', '--password', type=str, required=True,
                        help='password for the PostgreSQL user')
    parser.add_argument('-e', '--embedding-model-path', default='BAAI/bge-small-en',
                        help='path to the embedding model')
    args = parser.parse_args()

    main(args)

The third changed file modifies the `BigdlLLM` class: the hunk moves the loaded model onto the XPU device when an XPU `device_map` is requested.

@@ -239,6 +239,9 @@ class BigdlLLM(CustomLLM):
             model_name, load_in_4bit=True, **model_kwargs
         )

+        if 'xpu' in device_map:
+            self._model = self._model.to(device_map)
+
         # check context_window
         config_dict = self._model.config.to_dict()
         model_context_window = int(