Add LlamaIndex RAG (#10263)
* run demo
* format code
* add llamaindex
* add custom LLM with bigdl
* update
* add readme
* begin ut
* add unit test
* add license
* add license
* revised
* update
* modify docs
* remove data folder
* update
* modify prompt
* fixed
* fixed
* fixed
This commit is contained in:
parent
5d7243067c
commit
4e6cc424f1
6 changed files with 898 additions and 0 deletions
60
python/llm/example/CPU/LlamaIndex/README.md
Normal file
@@ -0,0 +1,60 @@
# LlamaIndex Examples

The examples here show how to use LlamaIndex with `bigdl-llm`.

The RAG example is modified from the [demo](https://docs.llamaindex.ai/en/stable/examples/low_level/oss_ingestion_retrieval.html).

## Install bigdl-llm

Follow the instructions in [Install](https://github.com/intel-analytics/BigDL/tree/main/python/llm#install).
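For a CPU-only environment this typically boils down to a single pip command; treat the linked instructions as authoritative for the exact package spec:

```bash
pip install --pre --upgrade bigdl-llm[all]
```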

## Install Required Dependencies for LlamaIndex Examples

### Install Site-packages
```bash
pip install llama-index-readers-file
pip install llama-index-vector-stores-postgres
pip install llama-index-embeddings-huggingface
```

### Install Postgres
> Note: There are plenty of open-source databases you can use. Here we provide an example using Postgres.
* Download and install Postgres by running the commands below.
```bash
sudo apt-get install postgresql-client
sudo apt-get install postgresql
```
* Initialize Postgres.
```bash
sudo su - postgres
psql
```
After running the commands above, we reach the Postgres console. We can then add a role as follows:
```sql
CREATE ROLE <user> WITH LOGIN PASSWORD '<password>';
ALTER ROLE <user> SUPERUSER;
```
* Install pgvector according to the [page](https://github.com/pgvector/pgvector). If you encounter problems during the installation, please refer to the [notes](https://github.com/pgvector/pgvector#installation-notes), which may be helpful.
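Once pgvector is built and installed, you can enable and sanity-check the extension from `psql` (an optional verification step; the vector table itself is created by the example script):

```sql
CREATE EXTENSION IF NOT EXISTS vector;
SELECT extversion FROM pg_extension WHERE extname = 'vector';
```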
* Download the data used for retrieval.
```bash
mkdir data
wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf"
```

## Run the examples

### Retrieval-augmented Generation
```bash
python rag.py -m MODEL_PATH -e EMBEDDING_MODEL_PATH -u USERNAME -p PASSWORD -q QUESTION -d DATA
```
arguments info:
- `-m MODEL_PATH`: **required**, path to the LLaMA model
- `-e EMBEDDING_MODEL_PATH`: path to the embedding model
- `-u USERNAME`: username for the Postgres database
- `-p PASSWORD`: password for the Postgres database
- `-q QUESTION`: question you want to ask
- `-d DATA`: path to the data used during retrieval
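
For example, an invocation that matches the sample output below could look like this (the model and data paths are placeholders for your local setup):

```bash
python rag.py -m ./models/Llama-2-7b-chat-hf -e BAAI/bge-small-en -u <username> -p <password> \
    -q "How does Llama 2 perform compared to other open-source models?" -d ./data/llama2.pdf
```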

Here is the sample output when using Llama-2-7b-chat-hf as the generation model, asking "How does Llama 2 perform compared to other open-source models?" with `llama2.pdf` as the retrieval data.
```
Llama 2 performs better than most open-source models on the benchmarks we tested. Specifically, it outperforms all open-source models on MMLU and BBH, and is close to GPT-3.5 on these benchmarks. Additionally, Llama 2 is on par or better than PaLM-2-L on almost all benchmarks. The only exception is the coding benchmarks, where Llama 2 lags significantly behind GPT-4 and PaLM-2-L. Overall, Llama 2 demonstrates strong performance on a wide range of natural language processing tasks.
```
247
python/llm/example/CPU/LlamaIndex/rag.py
Normal file
@@ -0,0 +1,247 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from sqlalchemy import make_url
from llama_index.vector_stores.postgres import PGVectorStore
# from llama_index.llms.llama_cpp import LlamaCPP
import psycopg2
from pathlib import Path
from llama_index.readers.file import PyMuPDFReader
from llama_index.core.schema import NodeWithScore
from typing import Optional
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core import QueryBundle
from llama_index.core.retrievers import BaseRetriever
from typing import Any, List
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.vector_stores import VectorStoreQuery
import argparse
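
# High-level flow of this example:
#   1. load_vector_database(): (re)create the example Postgres database and return a pgvector-backed store
#   2. load_data(): split the input PDF into sentence-based chunks wrapped in TextNode objects
#   3. main(): embed the nodes, add them to the vector store, retrieve the most similar chunks for the
#      question, and answer it with a RetrieverQueryEngine backed by BigdlLLM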

def load_vector_database(username, password):
    db_name = "example_db"
    host = "localhost"
    password = password
    port = "5432"
    user = username
    # conn = psycopg2.connect(connection_string)
    conn = psycopg2.connect(
        dbname="postgres",
        host=host,
        password=password,
        port=port,
        user=user,
    )
    conn.autocommit = True

    with conn.cursor() as c:
        c.execute(f"DROP DATABASE IF EXISTS {db_name}")
        c.execute(f"CREATE DATABASE {db_name}")

    vector_store = PGVectorStore.from_params(
        database=db_name,
        host=host,
        password=password,
        port=port,
        user=user,
        table_name="llama2_paper",
        embed_dim=384,  # embedding dimension of the bge-small-en embedding model
    )
    return vector_store


def load_data(data_path):
    loader = PyMuPDFReader()
    documents = loader.load(file_path=data_path)

    text_parser = SentenceSplitter(
        chunk_size=1024,
        # separator=" ",
    )
    text_chunks = []
    # maintain relationship with source doc index, to help inject doc metadata in (3)
    doc_idxs = []
    for doc_idx, doc in enumerate(documents):
        cur_text_chunks = text_parser.split_text(doc.text)
        text_chunks.extend(cur_text_chunks)
        doc_idxs.extend([doc_idx] * len(cur_text_chunks))

    from llama_index.core.schema import TextNode
    nodes = []
    for idx, text_chunk in enumerate(text_chunks):
        node = TextNode(
            text=text_chunk,
        )
        src_doc = documents[doc_idxs[idx]]
        node.metadata = src_doc.metadata
        nodes.append(node)
    return nodes


class VectorDBRetriever(BaseRetriever):
    """Retriever over a postgres vector store."""

    def __init__(
        self,
        vector_store: PGVectorStore,
        embed_model: Any,
        query_mode: str = "default",
        similarity_top_k: int = 2,
    ) -> None:
        """Init params."""
        self._vector_store = vector_store
        self._embed_model = embed_model
        self._query_mode = query_mode
        self._similarity_top_k = similarity_top_k
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve."""
        query_embedding = self._embed_model.get_query_embedding(
            query_bundle.query_str
        )
        vector_store_query = VectorStoreQuery(
            query_embedding=query_embedding,
            similarity_top_k=self._similarity_top_k,
            mode=self._query_mode,
        )
        query_result = self._vector_store.query(vector_store_query)

        nodes_with_scores = []
        for index, node in enumerate(query_result.nodes):
            score: Optional[float] = None
            if query_result.similarities is not None:
                score = query_result.similarities[index]
            nodes_with_scores.append(NodeWithScore(node=node, score=score))

        return nodes_with_scores


def completion_to_prompt(completion):
    return f"<|system|>\n</s>\n<|user|>\n{completion}</s>\n<|assistant|>\n"


# Transform a list of chat messages into zephyr-specific input
def messages_to_prompt(messages):
    prompt = ""
    for message in messages:
        if message.role == "system":
            prompt += f"<|system|>\n{message.content}</s>\n"
        elif message.role == "user":
            prompt += f"<|user|>\n{message.content}</s>\n"
        elif message.role == "assistant":
            prompt += f"<|assistant|>\n{message.content}</s>\n"

    # ensure we start with a system prompt, insert blank if needed
    if not prompt.startswith("<|system|>\n"):
        prompt = "<|system|>\n</s>\n" + prompt

    # add final assistant prompt
    prompt = prompt + "<|assistant|>\n"

    return prompt


def main(args):
    embed_model = HuggingFaceEmbedding(model_name=args.embedding_model_path)

    # Use custom LLM in BigDL
    from bigdl.llm.llamaindex.llms import BigdlLLM
    llm = BigdlLLM(
        model_name=args.model_path,
        tokenizer_name=args.model_path,
        context_window=512,
        max_new_tokens=32,
        generate_kwargs={"temperature": 0.7, "do_sample": False},
        model_kwargs={},
        messages_to_prompt=messages_to_prompt,
        completion_to_prompt=completion_to_prompt,
        device_map="cpu",
    )

    vector_store = load_vector_database(username=args.user, password=args.password)
    nodes = load_data(data_path=args.data)
    for node in nodes:
        node_embedding = embed_model.get_text_embedding(
            node.get_content(metadata_mode="all")
        )
        node.embedding = node_embedding

    vector_store.add(nodes)

    # query_str = "Can you tell me about the key concepts for safety finetuning"
    query_str = "Explain about the training data for Llama 2"
    query_embedding = embed_model.get_query_embedding(query_str)
    # construct vector store query

    query_mode = "default"
    # query_mode = "sparse"
    # query_mode = "hybrid"

    vector_store_query = VectorStoreQuery(
        query_embedding=query_embedding, similarity_top_k=2, mode=query_mode
    )
    # returns a VectorStoreQueryResult
    query_result = vector_store.query(vector_store_query)
    # print("Retrieval Results: ")
    # print(query_result.nodes[0].get_content())

    nodes_with_scores = []
    for index, node in enumerate(query_result.nodes):
        score: Optional[float] = None
        if query_result.similarities is not None:
            score = query_result.similarities[index]
        nodes_with_scores.append(NodeWithScore(node=node, score=score))

    retriever = VectorDBRetriever(
        vector_store, embed_model, query_mode="default", similarity_top_k=1
    )

    query_engine = RetrieverQueryEngine.from_args(retriever, llm=llm)

    # query_str = "How does Llama 2 perform compared to other open-source models?"
    query_str = args.question
    response = query_engine.query(query_str)

    print("------------RESPONSE GENERATION---------------------")
    print(str(response))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='LlamaIndex BigdlLLM Example')
    parser.add_argument('-m', '--model-path', type=str, required=True,
                        help='the path to the transformers model')
    parser.add_argument('-q', '--question', type=str, default='How does Llama 2 perform compared to other open-source models?',
                        help='question you want to ask.')
    parser.add_argument('-d', '--data', type=str, default='./data/llama2.pdf',
                        help='the data used during retrieval')
    parser.add_argument('-u', '--user', type=str, required=True,
                        help='the user name in the Postgres database')
    parser.add_argument('-p', '--password', type=str, required=True,
                        help='the password of the user in the database')
    parser.add_argument('-e', '--embedding-model-path', default="BAAI/bge-small-en",
                        help='the path to the embedding model')
    args = parser.parse_args()

    main(args)
20
python/llm/src/bigdl/llm/llamaindex/__init__.py
Normal file
@@ -0,0 +1,20 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.
34
python/llm/src/bigdl/llm/llamaindex/llms/__init__.py
Normal file
@@ -0,0 +1,34 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.

"""Wrappers on top of large language model APIs."""
from typing import Dict, Type

from .bigdlllm import *
from llama_index.core.base.llms.base import BaseLLM

__all__ = [
    "BigdlLLM",
]

type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
    "BigdlLLM": BigdlLLM,
}
449
python/llm/src/bigdl/llm/llamaindex/llms/bigdlllm.py
Normal file
@@ -0,0 +1,449 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This file is modified from
# https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/base/llms/base.py

# The MIT License

# Copyright (c) Harrison Chase

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

import logging
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import torch
from huggingface_hub import AsyncInferenceClient, InferenceClient, model_info
from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager
from llama_index.core.constants import (
    DEFAULT_CONTEXT_WINDOW,
    DEFAULT_NUM_OUTPUTS,
)
from llama_index.core.llms.callbacks import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.core.llms.custom import CustomLLM

from llama_index.core.base.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.core.prompts.base import PromptTemplate
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
from transformers import (
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers import AutoTokenizer, LlamaTokenizer


DEFAULT_HUGGINGFACE_MODEL = "meta-llama/Llama-2-7b-chat-hf"

logger = logging.getLogger(__name__)


class BigdlLLM(CustomLLM):
    """Wrapper around the BigDL-LLM model.

    Example:
        .. code-block:: python

            from bigdl.llm.llamaindex.llms import BigdlLLM
            llm = BigdlLLM(model_path="/path/to/llama/model")
    """

    model_name: str = Field(
        default=DEFAULT_HUGGINGFACE_MODEL,
        description=(
            "The model name to use from HuggingFace. "
            "Unused if `model` is passed in directly."
        ),
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum number of tokens available for input.",
        gt=0,
    )
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The maximum number of tokens to generate.",
        gt=0,
    )
    system_prompt: str = Field(
        default="",
        description=(
            "The system prompt, containing any extra instructions or context. "
            "The model card on HuggingFace should specify if this is needed."
        ),
    )
    query_wrapper_prompt: PromptTemplate = Field(
        default=PromptTemplate("{query_str}"),
        description=(
            "The query wrapper prompt, containing the query placeholder. "
            "The model card on HuggingFace should specify if this is needed. "
            "Should contain a `{query_str}` placeholder."
        ),
    )
    tokenizer_name: str = Field(
        default=DEFAULT_HUGGINGFACE_MODEL,
        description=(
            "The name of the tokenizer to use from HuggingFace. "
            "Unused if `tokenizer` is passed in directly."
        ),
    )
    device_map: str = Field(
        default="auto", description="The device_map to use. Defaults to 'auto'."
    )
    stopping_ids: List[int] = Field(
        default_factory=list,
        description=(
            "The stopping ids to use. "
            "Generation stops when these token IDs are predicted."
        ),
    )
    tokenizer_outputs_to_remove: list = Field(
        default_factory=list,
        description=(
            "The outputs to remove from the tokenizer. "
            "Sometimes huggingface tokenizers return extra inputs that cause errors."
        ),
    )
    tokenizer_kwargs: dict = Field(
        default_factory=dict, description="The kwargs to pass to the tokenizer."
    )
    model_kwargs: dict = Field(
        default_factory=dict,
        description="The kwargs to pass to the model during initialization.",
    )
    generate_kwargs: dict = Field(
        default_factory=dict,
        description="The kwargs to pass to the model during generation.",
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            LLMMetadata.__fields__["is_chat_model"].field_info.description
            + " Be sure to verify that you either pass an appropriate tokenizer "
            "that can convert prompts to properly formatted chat messages or a "
            "`messages_to_prompt` that does so."
        ),
    )

    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _stopping_criteria: Any = PrivateAttr()

    def __init__(
        self,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        query_wrapper_prompt: Union[str, PromptTemplate] = "{query_str}",
        tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL,
        model_name: str = DEFAULT_HUGGINGFACE_MODEL,
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        device_map: Optional[str] = "auto",
        stopping_ids: Optional[List[int]] = None,
        tokenizer_kwargs: Optional[dict] = None,
        tokenizer_outputs_to_remove: Optional[list] = None,
        model_kwargs: Optional[dict] = None,
        generate_kwargs: Optional[dict] = None,
        is_chat_model: Optional[bool] = False,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: str = "",
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """
        Construct BigdlLLM.

        Args:
            context_window: The maximum number of tokens available for input.
            max_new_tokens: The maximum number of tokens to generate.
            query_wrapper_prompt: The query wrapper prompt, containing the query placeholder.
                Should contain a `{query_str}` placeholder.
            tokenizer_name: The name of the tokenizer to use from HuggingFace.
                Unused if `tokenizer` is passed in directly.
            model_name: The model name to use from HuggingFace.
                Unused if `model` is passed in directly.
            model: The HuggingFace model.
            tokenizer: The tokenizer.
            device_map: The device_map to use. Defaults to 'auto'.
            stopping_ids: The stopping ids to use.
                Generation stops when these token IDs are predicted.
            tokenizer_kwargs: The kwargs to pass to the tokenizer.
            tokenizer_outputs_to_remove: The outputs to remove from the tokenizer.
                Sometimes huggingface tokenizers return extra inputs that cause errors.
            model_kwargs: The kwargs to pass to the model during initialization.
            generate_kwargs: The kwargs to pass to the model during generation.
            is_chat_model: Whether the model is a chat model.
            callback_manager: Callback manager.
            system_prompt: The system prompt, containing any extra instructions or context.
            messages_to_prompt: Function to convert messages to a prompt.
            completion_to_prompt: Function to convert a completion to a prompt.
            pydantic_program_mode: DEFAULT.
            output_parser: BaseOutputParser.

        Returns:
            None.
        """
        model_kwargs = model_kwargs or {}
        from bigdl.llm.transformers import AutoModelForCausalLM
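        # Load the model through BigDL-LLM's transformers API; load_in_4bit=True applies
        # BigDL-LLM's INT4 low-bit optimization to the loaded weights.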
        self._model = model or AutoModelForCausalLM.from_pretrained(
            model_name, load_in_4bit=True, **model_kwargs
        )

        # check context_window
        config_dict = self._model.config.to_dict()
        model_context_window = int(
            config_dict.get("max_position_embeddings", context_window)
        )
        if model_context_window and model_context_window < context_window:
            logger.warning(
                f"Supplied context_window {context_window} is greater "
                f"than the model's max input size {model_context_window}. "
                "Disable this warning by setting a lower context_window."
            )
            context_window = model_context_window

        tokenizer_kwargs = tokenizer_kwargs or {}
        if "max_length" not in tokenizer_kwargs:
            tokenizer_kwargs["max_length"] = context_window

        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(
            tokenizer_name, **tokenizer_kwargs
        )

        if tokenizer_name != model_name:
            logger.warning(
                f"The model `{model_name}` and tokenizer `{tokenizer_name}` "
                f"are different, please ensure that they are compatible."
            )

        # setup stopping criteria
        stopping_ids_list = stopping_ids or []

        class StopOnTokens(StoppingCriteria):
            def __call__(
                self,
                input_ids: torch.LongTensor,
                scores: torch.FloatTensor,
                **kwargs: Any,
            ) -> bool:
                for stop_id in stopping_ids_list:
                    if input_ids[0][-1] == stop_id:
                        return True
                return False

        self._stopping_criteria = StoppingCriteriaList([StopOnTokens()])

        if isinstance(query_wrapper_prompt, str):
            query_wrapper_prompt = PromptTemplate(query_wrapper_prompt)

        messages_to_prompt = messages_to_prompt or self._tokenizer_messages_to_prompt

        super().__init__(
            context_window=context_window,
            max_new_tokens=max_new_tokens,
            query_wrapper_prompt=query_wrapper_prompt,
            tokenizer_name=tokenizer_name,
            model_name=model_name,
            device_map=device_map,
            stopping_ids=stopping_ids or [],
            tokenizer_kwargs=tokenizer_kwargs or {},
            tokenizer_outputs_to_remove=tokenizer_outputs_to_remove or [],
            model_kwargs=model_kwargs or {},
            generate_kwargs=generate_kwargs or {},
            is_chat_model=is_chat_model,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """
        Get class name.

        Returns:
            Str of class name.
        """
        return "BigDL_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """
        Get metadata.

        Returns:
            LLMMetadata containing context_window,
            num_output, model_name, is_chat_model.
        """
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_name,
            is_chat_model=self.is_chat_model,
        )

    def _tokenizer_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """
        Use the tokenizer to convert messages to prompt. Fallback to generic.

        Args:
            messages: Sequence of ChatMessage.

        Returns:
            Str of the prompt.
        """
        if hasattr(self._tokenizer, "apply_chat_template"):
            messages_dict = [
                {"role": message.role.value, "content": message.content}
                for message in messages
            ]
            tokens = self._tokenizer.apply_chat_template(messages_dict)
            return self._tokenizer.decode(tokens)

        return generic_messages_to_prompt(messages)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """
        Complete by LLM.

        Args:
            prompt: Prompt for completion.
            formatted: Whether the prompt is already formatted by the wrapper.
            kwargs: Other kwargs for complete.

        Returns:
            CompletionResponse after generation.
        """
        full_prompt = prompt
        if not formatted:
            if self.query_wrapper_prompt:
                full_prompt = self.query_wrapper_prompt.format(query_str=prompt)
            if self.system_prompt:
                full_prompt = f"{self.system_prompt} {full_prompt}"
        input_ids = self._tokenizer(full_prompt, return_tensors="pt")
        input_ids = input_ids.to(self._model.device)
        # remove keys from the tokenizer if needed, to avoid HF errors
        for key in self.tokenizer_outputs_to_remove:
            if key in input_ids:
                input_ids.pop(key, None)
        tokens = self._model.generate(
            **input_ids,
            max_new_tokens=self.max_new_tokens,
            stopping_criteria=self._stopping_criteria,
            **self.generate_kwargs,
        )
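        # generate() returns the prompt followed by the completion; keep only the newly
        # generated tokens before decoding.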
        completion_tokens = tokens[0][input_ids["input_ids"].size(1):]
        completion = self._tokenizer.decode(completion_tokens, skip_special_tokens=True)

        return CompletionResponse(text=completion, raw={"model_output": tokens})

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """
        Complete by LLM in stream.

        Args:
            prompt: Prompt for completion.
            formatted: Whether the prompt is already formatted by the wrapper.
            kwargs: Other kwargs for complete.

        Returns:
            CompletionResponseGen that yields the generated text incrementally.
        """
        from transformers import TextStreamer
        full_prompt = prompt
        if not formatted:
            if self.query_wrapper_prompt:
                full_prompt = self.query_wrapper_prompt.format(query_str=prompt)
            if self.system_prompt:
                full_prompt = f"{self.system_prompt} {full_prompt}"

        input_ids = self._tokenizer.encode(full_prompt, return_tensors="pt")
        input_ids = input_ids.to(self._model.device)

        for key in self.tokenizer_outputs_to_remove:
            if key in input_ids:
                input_ids.pop(key, None)

        streamer = TextStreamer(self._tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=self.max_new_tokens,
            stopping_criteria=self._stopping_criteria,
            **self.generate_kwargs,
        )
        thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
        thread.start()

        # create generator based off of streamer
        def gen() -> CompletionResponseGen:
            text = ""
            for x in streamer:
                text += x
                yield CompletionResponse(text=text, delta=x)

        return gen()
88
python/llm/test/llamaindex/test_llamaindex.py
Normal file
@@ -0,0 +1,88 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from bigdl.llm.langchain.llms import TransformersLLM, TransformersPipelineLLM, \
    LlamaLLM, BloomLLM
from bigdl.llm.langchain.embeddings import TransformersEmbeddings, LlamaEmbeddings, \
    BloomEmbeddings


from langchain.document_loaders import WebBaseLoader
from langchain.indexes import VectorstoreIndexCreator

from langchain.chains.question_answering import load_qa_chain
from langchain.chains.chat_vector_db.prompts import (CONDENSE_QUESTION_PROMPT,
                                                     QA_PROMPT)
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

import pytest
from unittest import TestCase
import os
from bigdl.llm.llamaindex.llms import BigdlLLM


class Test_LlamaIndex_Transformers_API(TestCase):
    def setUp(self):
        self.auto_model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
        self.auto_causal_model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH')
        self.llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH')
        self.bloom_model_path = os.environ.get('BLOOM_ORIGIN_PATH')
        thread_num = os.environ.get('THREAD_NUM')
        if thread_num is not None:
            self.n_threads = int(thread_num)
        else:
            self.n_threads = 2

    def completion_to_prompt(self, completion):
        return f"<|system|>\n</s>\n<|user|>\n{completion}</s>\n<|assistant|>\n"

    def messages_to_prompt(self, messages):
        prompt = ""
        for message in messages:
            if message.role == "system":
                prompt += f"<|system|>\n{message.content}</s>\n"
            elif message.role == "user":
                prompt += f"<|user|>\n{message.content}</s>\n"
            elif message.role == "assistant":
                prompt += f"<|assistant|>\n{message.content}</s>\n"

        # ensure we start with a system prompt, insert blank if needed
        if not prompt.startswith("<|system|>\n"):
            prompt = "<|system|>\n</s>\n" + prompt

        # add final assistant prompt
        prompt = prompt + "<|assistant|>\n"
        return prompt

    def test_bigdl_llm(self):
        llm = BigdlLLM(
            model_name=self.llama_model_path,
            tokenizer_name=self.llama_model_path,
            context_window=512,
            max_new_tokens=32,
            model_kwargs={},
            generate_kwargs={"temperature": 0.7, "do_sample": False},
            messages_to_prompt=self.messages_to_prompt,
            completion_to_prompt=self.completion_to_prompt,
            device_map="cpu",
        )
        res = llm.complete("What is AI?")
        assert res is not None


if __name__ == '__main__':
    pytest.main([__file__])