Add LlamaIndex RAG (#10263)

* run demo

* format code

* add llamaindex

* add custom LLM with bigdl

* update

* add readme

* begin ut

* add unit test

* add license

* add license

* revised

* update

* modify docs

* remove data folder

* update

* modify prompt

* fixed

* fixed

* fixed
This commit is contained in:
Zhicun 2024-02-29 15:21:19 +08:00 committed by GitHub
parent 5d7243067c
commit 4e6cc424f1
6 changed files with 898 additions and 0 deletions


@@ -0,0 +1,60 @@
# LlamaIndex Examples
The examples here show how to use LlamaIndex with `bigdl-llm`.
The RAG example is modified from the [demo](https://docs.llamaindex.ai/en/stable/examples/low_level/oss_ingestion_retrieval.html).
## Install bigdl-llm
Follow the instructions in [Install](https://github.com/intel-analytics/BigDL/tree/main/python/llm#install).
## Install Required Dependencies for LlamaIndex Examples
### Install Site-packages
```bash
pip install llama-index-readers-file
pip install llama-index-vector-stores-postgres
pip install llama-index-embeddings-huggingface
```
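Optionally, you can confirm the integrations are importable before continuing. The check below is a minimal sketch that simply imports the three packages installed above (the same modules the example script uses):
```python
# Quick import check for the llama-index integrations used by the RAG example.
from llama_index.readers.file import PyMuPDFReader
from llama_index.vector_stores.postgres import PGVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

print("llama-index reader, vector store, and embedding integrations imported successfully")
```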
### Install Postgres
> Note: There are plenty of open-source databases you can use. Here we provide an example using Postgres.
* Download and install postgres by running the commands below.
```bash
sudo apt-get install postgresql-client
sudo apt-get install postgresql
```
* Initialize postgres.
```bash
sudo su - postgres
psql
```
After running these commands you are in the postgres console. Add a role as follows:
```sql
CREATE ROLE <user> WITH LOGIN PASSWORD '<password>';
ALTER ROLE <user> SUPERUSER;
```
* Install pgvector according to the [page](https://github.com/pgvector/pgvector). If you run into problems during installation, the [installation notes](https://github.com/pgvector/pgvector#installation-notes) may be helpful. A short connection and extension check is sketched after this list.
* Download the source document used for retrieval.
```bash
mkdir data
wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf"
```
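Before moving on, you can optionally verify that the role you created can connect and that the `pgvector` extension is available. This is a minimal sketch, not part of the example itself; it assumes `psycopg2` is installed (the example script imports it) and that Postgres is listening on the default port. Replace `<user>` and `<password>` with the role you created above.
```python
import psycopg2

# Connect with the role created above (placeholder credentials; replace them).
conn = psycopg2.connect(dbname="postgres", host="localhost", port="5432",
                        user="<user>", password="<password>")
conn.autocommit = True
with conn.cursor() as c:
    # This fails if pgvector is not installed correctly.
    c.execute("CREATE EXTENSION IF NOT EXISTS vector")
    c.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
    print("pgvector version:", c.fetchone()[0])
```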
## Run the examples
### Retrieval-augmented Generation
```bash
python rag.py -m MODEL_PATH -e EMBEDDING_MODEL_PATH -u USERNAME -p PASSWORD -q QUESTION -d DATA
```
arguments info:
- `-m MODEL_PATH`: **required**, path to the llama model
- `-e EMBEDDING_MODEL_PATH`: path to the embedding model
- `-u USERNAME`: username in the postgres database
- `-p PASSWORD`: password in the postgres database
- `-q QUESTION`: question you want to ask
- `-d DATA`: path to data used during retrieval
Here is the sample output when using Llama-2-7b-chat-hf as the generation model, asking "How does Llama 2 perform compared to other open-source models?" with llama2.pdf as the retrieval source.
```
Llama 2 performs better than most open-source models on the benchmarks we tested. Specifically, it outperforms all open-source models on MMLU and BBH, and is close to GPT-3.5 on these benchmarks. Additionally, Llama 2 is on par or better than PaLM-2-L on almost all benchmarks. The only exception is the coding benchmarks, where Llama 2 lags significantly behind GPT-4 and PaLM-2-L. Overall, Llama 2 demonstrates strong performance on a wide range of natural language processing tasks.
```


@@ -0,0 +1,247 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from sqlalchemy import make_url
from llama_index.vector_stores.postgres import PGVectorStore
# from llama_index.llms.llama_cpp import LlamaCPP
import psycopg2
from pathlib import Path
from llama_index.readers.file import PyMuPDFReader
from llama_index.core.schema import NodeWithScore
from typing import Optional
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core import QueryBundle
from llama_index.core.retrievers import BaseRetriever
from typing import Any, List
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.vector_stores import VectorStoreQuery
import argparse
def load_vector_database(username, password):
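    """Drop and recreate `example_db`, then return a PGVectorStore backed by
    the local pgvector-enabled Postgres instance for storing node embeddings."""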
db_name = "example_db"
host = "localhost"
password = password
port = "5432"
user = username
# conn = psycopg2.connect(connection_string)
conn = psycopg2.connect(
dbname="postgres",
host=host,
password=password,
port=port,
user=user,
)
conn.autocommit = True
with conn.cursor() as c:
c.execute(f"DROP DATABASE IF EXISTS {db_name}")
c.execute(f"CREATE DATABASE {db_name}")
vector_store = PGVectorStore.from_params(
database=db_name,
host=host,
password=password,
port=port,
user=user,
table_name="llama2_paper",
        embed_dim=384,  # embedding dimension of bge-small-en
)
return vector_store
def load_data(data_path):
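    """Load the PDF, split it into ~1024-token chunks with SentenceSplitter,
    and wrap each chunk in a TextNode carrying the source document's metadata."""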
loader = PyMuPDFReader()
documents = loader.load(file_path=data_path)
text_parser = SentenceSplitter(
chunk_size=1024,
# separator=" ",
)
text_chunks = []
# maintain relationship with source doc index, to help inject doc metadata in (3)
doc_idxs = []
for doc_idx, doc in enumerate(documents):
cur_text_chunks = text_parser.split_text(doc.text)
text_chunks.extend(cur_text_chunks)
doc_idxs.extend([doc_idx] * len(cur_text_chunks))
from llama_index.core.schema import TextNode
nodes = []
for idx, text_chunk in enumerate(text_chunks):
node = TextNode(
text=text_chunk,
)
src_doc = documents[doc_idxs[idx]]
node.metadata = src_doc.metadata
nodes.append(node)
return nodes
class VectorDBRetriever(BaseRetriever):
"""Retriever over a postgres vector store."""
def __init__(
self,
vector_store: PGVectorStore,
embed_model: Any,
query_mode: str = "default",
similarity_top_k: int = 2,
) -> None:
"""Init params."""
self._vector_store = vector_store
self._embed_model = embed_model
self._query_mode = query_mode
self._similarity_top_k = similarity_top_k
super().__init__()
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Retrieve."""
query_embedding = self._embed_model.get_query_embedding(
query_bundle.query_str
)
vector_store_query = VectorStoreQuery(
query_embedding=query_embedding,
similarity_top_k=self._similarity_top_k,
mode=self._query_mode,
)
query_result = self._vector_store.query(vector_store_query)
nodes_with_scores = []
for index, node in enumerate(query_result.nodes):
score: Optional[float] = None
if query_result.similarities is not None:
score = query_result.similarities[index]
nodes_with_scores.append(NodeWithScore(node=node, score=score))
return nodes_with_scores
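# Wrap a bare completion prompt in the Zephyr-style chat template (empty system prompt)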
def completion_to_prompt(completion):
return f"<|system|>\n</s>\n<|user|>\n{completion}</s>\n<|assistant|>\n"
# Transform a list of chat messages into zephyr-specific input
def messages_to_prompt(messages):
prompt = ""
for message in messages:
if message.role == "system":
prompt += f"<|system|>\n{message.content}</s>\n"
elif message.role == "user":
prompt += f"<|user|>\n{message.content}</s>\n"
elif message.role == "assistant":
prompt += f"<|assistant|>\n{message.content}</s>\n"
# ensure we start with a system prompt, insert blank if needed
if not prompt.startswith("<|system|>\n"):
prompt = "<|system|>\n</s>\n" + prompt
# add final assistant prompt
prompt = prompt + "<|assistant|>\n"
return prompt
def main(args):
embed_model = HuggingFaceEmbedding(model_name=args.embedding_model_path)
# Use custom LLM in BigDL
from bigdl.llm.llamaindex.llms import BigdlLLM
llm = BigdlLLM(
model_name=args.model_path,
tokenizer_name=args.model_path,
context_window=512,
max_new_tokens=32,
generate_kwargs={"temperature": 0.7, "do_sample": False},
model_kwargs={},
messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt,
device_map="cpu",
)
vector_store = load_vector_database(username=args.user, password=args.password)
nodes = load_data(data_path=args.data)
for node in nodes:
node_embedding = embed_model.get_text_embedding(
node.get_content(metadata_mode="all")
)
node.embedding = node_embedding
vector_store.add(nodes)
# query_str = "Can you tell me about the key concepts for safety finetuning"
query_str = "Explain about the training data for Llama 2"
query_embedding = embed_model.get_query_embedding(query_str)
# construct vector store query
query_mode = "default"
# query_mode = "sparse"
# query_mode = "hybrid"
vector_store_query = VectorStoreQuery(
query_embedding=query_embedding, similarity_top_k=2, mode=query_mode
)
# returns a VectorStoreQueryResult
query_result = vector_store.query(vector_store_query)
# print("Retrieval Results: ")
# print(query_result.nodes[0].get_content())
nodes_with_scores = []
for index, node in enumerate(query_result.nodes):
score: Optional[float] = None
if query_result.similarities is not None:
score = query_result.similarities[index]
nodes_with_scores.append(NodeWithScore(node=node, score=score))
retriever = VectorDBRetriever(
vector_store, embed_model, query_mode="default", similarity_top_k=1
)
query_engine = RetrieverQueryEngine.from_args(retriever, llm=llm)
# query_str = "How does Llama 2 perform compared to other open-source models?"
query_str = args.question
response = query_engine.query(query_str)
print("------------RESPONSE GENERATION---------------------")
print(str(response))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='LlamaIndex BigdlLLM Example')
    parser.add_argument('-m','--model-path', type=str, required=True,
                        help='the path to the transformers model')
    parser.add_argument('-q', '--question', type=str, default='How does Llama 2 perform compared to other open-source models?',
                        help='the question you want to ask')
    parser.add_argument('-d','--data',type=str, default='./data/llama2.pdf',
                        help="the data used during retrieval")
    parser.add_argument('-u', '--user', type=str, required=True,
                        help="the username in the postgres database")
    parser.add_argument('-p','--password', type=str, required=True,
                        help="the password of the user in the database")
    parser.add_argument('-e','--embedding-model-path',default="BAAI/bge-small-en",
                        help="the path to the embedding model")
args = parser.parse_args()
main(args)


@@ -0,0 +1,20 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.


@@ -0,0 +1,34 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.
"""Wrappers on top of large language models APIs."""
from typing import Dict, Type
from .bigdlllm import *
from llama_index.core.base.llms.base import BaseLLM
__all__ = [
"BigdlLLM",
]
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"BigdlLLM": BigdlLLM,
}


@@ -0,0 +1,449 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The file is modified from
# https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/base/llms/base.py
# The MIT License
# Copyright (c) Harrison Chase
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import logging
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import torch
from huggingface_hub import AsyncInferenceClient, InferenceClient, model_info
from llama_index.core.base.llms.types import (
ChatMessage,
ChatResponse,
ChatResponseAsyncGen,
ChatResponseGen,
CompletionResponse,
CompletionResponseAsyncGen,
CompletionResponseGen,
LLMMetadata,
MessageRole,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager
from llama_index.core.constants import (
DEFAULT_CONTEXT_WINDOW,
DEFAULT_NUM_OUTPUTS,
)
from llama_index.core.llms.callbacks import (
llm_chat_callback,
llm_completion_callback,
)
from llama_index.core.llms.custom import CustomLLM
from llama_index.core.base.llms.generic_utils import (
messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.core.prompts.base import PromptTemplate
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
from transformers import (
StoppingCriteria,
StoppingCriteriaList,
)
from transformers import AutoTokenizer, LlamaTokenizer
DEFAULT_HUGGINGFACE_MODEL = "meta-llama/Llama-2-7b-chat-hf"
logger = logging.getLogger(__name__)
class BigdlLLM(CustomLLM):
"""Wrapper around the BigDL-LLM
Example:
.. code-block:: python
from bigdl.llm.llamaindex.llms import BigdlLLM
        llm = BigdlLLM(model_name="/path/to/llama/model", tokenizer_name="/path/to/llama/model")
"""
model_name: str = Field(
default=DEFAULT_HUGGINGFACE_MODEL,
description=(
"The model name to use from HuggingFace. "
"Unused if `model` is passed in directly."
),
)
context_window: int = Field(
default=DEFAULT_CONTEXT_WINDOW,
description="The maximum number of tokens available for input.",
gt=0,
)
max_new_tokens: int = Field(
default=DEFAULT_NUM_OUTPUTS,
description="The maximum number of tokens to generate.",
gt=0,
)
system_prompt: str = Field(
default="",
description=(
"The system prompt, containing any extra instructions or context. "
"The model card on HuggingFace should specify if this is needed."
),
)
query_wrapper_prompt: PromptTemplate = Field(
default=PromptTemplate("{query_str}"),
description=(
"The query wrapper prompt, containing the query placeholder. "
"The model card on HuggingFace should specify if this is needed. "
"Should contain a `{query_str}` placeholder."
),
)
tokenizer_name: str = Field(
default=DEFAULT_HUGGINGFACE_MODEL,
description=(
"The name of the tokenizer to use from HuggingFace. "
"Unused if `tokenizer` is passed in directly."
),
)
device_map: str = Field(
default="auto", description="The device_map to use. Defaults to 'auto'."
)
stopping_ids: List[int] = Field(
default_factory=list,
description=(
"The stopping ids to use. "
"Generation stops when these token IDs are predicted."
),
)
tokenizer_outputs_to_remove: list = Field(
default_factory=list,
description=(
"The outputs to remove from the tokenizer. "
"Sometimes huggingface tokenizers return extra inputs that cause errors."
),
)
tokenizer_kwargs: dict = Field(
default_factory=dict, description="The kwargs to pass to the tokenizer."
)
model_kwargs: dict = Field(
default_factory=dict,
description="The kwargs to pass to the model during initialization.",
)
generate_kwargs: dict = Field(
default_factory=dict,
description="The kwargs to pass to the model during generation.",
)
is_chat_model: bool = Field(
default=False,
description=(
LLMMetadata.__fields__["is_chat_model"].field_info.description
+ " Be sure to verify that you either pass an appropriate tokenizer "
"that can convert prompts to properly formatted chat messages or a "
"`messages_to_prompt` that does so."
),
)
_model: Any = PrivateAttr()
_tokenizer: Any = PrivateAttr()
_stopping_criteria: Any = PrivateAttr()
def __init__(
self,
context_window: int = DEFAULT_CONTEXT_WINDOW,
max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
query_wrapper_prompt: Union[str, PromptTemplate]="{query_str}",
tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL,
model_name: str = DEFAULT_HUGGINGFACE_MODEL,
model: Optional[Any] = None,
tokenizer: Optional[Any] = None,
device_map: Optional[str] = "auto",
stopping_ids: Optional[List[int]] = None,
tokenizer_kwargs: Optional[dict] = None,
tokenizer_outputs_to_remove: Optional[list] = None,
model_kwargs: Optional[dict] = None,
generate_kwargs: Optional[dict] = None,
is_chat_model: Optional[bool] = False,
callback_manager: Optional[CallbackManager] = None,
system_prompt: str = "",
messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]]=None,
completion_to_prompt: Optional[Callable[[str], str]]=None,
pydantic_program_mode: PydanticProgramMode=PydanticProgramMode.DEFAULT,
output_parser: Optional[BaseOutputParser] = None,
) -> None:
"""
Construct BigdlLLM.
Args:
context_window: The maximum number of tokens available for input.
max_new_tokens: The maximum number of tokens to generate.
query_wrapper_prompt: The query wrapper prompt, containing the query placeholder.
Should contain a `{query_str}` placeholder.
tokenizer_name: The name of the tokenizer to use from HuggingFace.
Unused if `tokenizer` is passed in directly.
model_name: The model name to use from HuggingFace.
Unused if `model` is passed in directly.
model: The HuggingFace model.
tokenizer: The tokenizer.
device_map: The device_map to use. Defaults to 'auto'.
stopping_ids: The stopping ids to use.
Generation stops when these token IDs are predicted.
tokenizer_kwargs: The kwargs to pass to the tokenizer.
tokenizer_outputs_to_remove: The outputs to remove from the tokenizer.
Sometimes huggingface tokenizers return extra inputs that cause errors.
model_kwargs: The kwargs to pass to the model during initialization.
generate_kwargs: The kwargs to pass to the model during generation.
            is_chat_model: Whether the model is a chat model.
callback_manager: Callback manager.
system_prompt: The system prompt, containing any extra instructions or context.
messages_to_prompt: Function to convert messages to prompt.
            completion_to_prompt: Function to convert a completion prompt into the model's prompt format.
pydantic_program_mode: DEFAULT.
output_parser: BaseOutputParser.
Returns:
None.
"""
model_kwargs = model_kwargs or {}
from bigdl.llm.transformers import AutoModelForCausalLM
self._model = model or AutoModelForCausalLM.from_pretrained(
model_name, load_in_4bit=True, **model_kwargs
)
# check context_window
config_dict = self._model.config.to_dict()
model_context_window = int(
config_dict.get("max_position_embeddings", context_window)
)
if model_context_window and model_context_window < context_window:
logger.warning(
f"Supplied context_window {context_window} is greater "
f"than the model's max input size {model_context_window}. "
"Disable this warning by setting a lower context_window."
)
context_window = model_context_window
tokenizer_kwargs = tokenizer_kwargs or {}
if "max_length" not in tokenizer_kwargs:
tokenizer_kwargs["max_length"] = context_window
self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(
tokenizer_name, **tokenizer_kwargs
)
if tokenizer_name != model_name:
logger.warning(
f"The model `{model_name}` and tokenizer `{tokenizer_name}` "
f"are different, please ensure that they are compatible."
)
# setup stopping criteria
stopping_ids_list = stopping_ids or []
class StopOnTokens(StoppingCriteria):
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
**kwargs: Any,
) -> bool:
for stop_id in stopping_ids_list:
if input_ids[0][-1] == stop_id:
return True
return False
self._stopping_criteria = StoppingCriteriaList([StopOnTokens()])
if isinstance(query_wrapper_prompt, str):
query_wrapper_prompt = PromptTemplate(query_wrapper_prompt)
messages_to_prompt = messages_to_prompt or self._tokenizer_messages_to_prompt
super().__init__(
context_window=context_window,
max_new_tokens=max_new_tokens,
query_wrapper_prompt=query_wrapper_prompt,
tokenizer_name=tokenizer_name,
model_name=model_name,
device_map=device_map,
stopping_ids=stopping_ids or [],
tokenizer_kwargs=tokenizer_kwargs or {},
tokenizer_outputs_to_remove=tokenizer_outputs_to_remove or [],
model_kwargs=model_kwargs or {},
generate_kwargs=generate_kwargs or {},
is_chat_model=is_chat_model,
callback_manager=callback_manager,
system_prompt=system_prompt,
messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt,
pydantic_program_mode=pydantic_program_mode,
output_parser=output_parser,
)
@classmethod
def class_name(cls) -> str:
"""
Get class name.
Args:
Returns:
Str of class name.
"""
return "BigDL_LLM"
@property
def metadata(self) -> LLMMetadata:
"""
Get meta data.
Args:
Returns:
            LLMMetadata containing context_window,
            num_output, model_name, and is_chat_model.
"""
return LLMMetadata(
context_window=self.context_window,
num_output=self.max_new_tokens,
model_name=self.model_name,
is_chat_model=self.is_chat_model,
)
def _tokenizer_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
"""
Use the tokenizer to convert messages to prompt. Fallback to generic.
Args:
messages: Sequence of ChatMessage.
Returns:
Str of response.
"""
if hasattr(self._tokenizer, "apply_chat_template"):
messages_dict = [
{"role": message.role.value, "content": message.content}
for message in messages
]
tokens = self._tokenizer.apply_chat_template(messages_dict)
return self._tokenizer.decode(tokens)
return generic_messages_to_prompt(messages)
@llm_completion_callback()
def complete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> CompletionResponse:
"""
Complete by LLM.
Args:
prompt: Prompt for completion.
formatted: Whether the prompt is formatted by wrapper.
kwargs: Other kwargs for complete.
Returns:
            CompletionResponse after generation.
"""
full_prompt = prompt
if not formatted:
if self.query_wrapper_prompt:
full_prompt = self.query_wrapper_prompt.format(query_str=prompt)
if self.system_prompt:
full_prompt = f"{self.system_prompt} {full_prompt}"
input_ids = self._tokenizer(full_prompt, return_tensors="pt")
input_ids = input_ids.to(self._model.device)
# remove keys from the tokenizer if needed, to avoid HF errors
for key in self.tokenizer_outputs_to_remove:
if key in input_ids:
input_ids.pop(key, None)
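        # Generate, then strip the prompt tokens so only the completion is decoded.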
tokens = self._model.generate(
**input_ids,
max_new_tokens=self.max_new_tokens,
stopping_criteria=self._stopping_criteria,
**self.generate_kwargs,
)
completion_tokens = tokens[0][input_ids["input_ids"].size(1):]
completion = self._tokenizer.decode(completion_tokens, skip_special_tokens=True)
return CompletionResponse(text=completion, raw={"model_output": tokens})
@llm_completion_callback()
def stream_complete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> CompletionResponseGen:
"""
Complete by LLM in stream.
Args:
prompt: Prompt for completion.
formatted: Whether the prompt is formatted by wrapper.
kwargs: Other kwargs for complete.
Returns:
            CompletionResponse after generation.
"""
from transformers import TextStreamer
full_prompt = prompt
if not formatted:
if self.query_wrapper_prompt:
full_prompt = self.query_wrapper_prompt.format(query_str=prompt)
if self.system_prompt:
full_prompt = f"{self.system_prompt} {full_prompt}"
input_ids = self._tokenizer.encode(full_prompt, return_tensors="pt")
input_ids = input_ids.to(self._model.device)
for key in self.tokenizer_outputs_to_remove:
if key in input_ids:
input_ids.pop(key, None)
streamer = TextStreamer(self._tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
input_ids,
streamer=streamer,
max_new_tokens=self.max_new_tokens,
stopping_criteria=self._stopping_criteria,
**self.generate_kwargs,
)
thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
thread.start()
# create generator based off of streamer
def gen() -> CompletionResponseGen:
text = ""
for x in streamer:
text += x
yield CompletionResponse(text=text, delta=x)
return gen()


@@ -0,0 +1,88 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from bigdl.llm.langchain.llms import TransformersLLM, TransformersPipelineLLM, \
LlamaLLM, BloomLLM
from bigdl.llm.langchain.embeddings import TransformersEmbeddings, LlamaEmbeddings, \
BloomEmbeddings
from langchain.document_loaders import WebBaseLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.chat_vector_db.prompts import (CONDENSE_QUESTION_PROMPT,
QA_PROMPT)
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
import pytest
from unittest import TestCase
import os
from bigdl.llm.llamaindex.llms import BigdlLLM
class Test_LlamaIndex_Transformers_API(TestCase):
def setUp(self):
self.auto_model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
self.auto_causal_model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH')
self.llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH')
self.bloom_model_path = os.environ.get('BLOOM_ORIGIN_PATH')
thread_num = os.environ.get('THREAD_NUM')
if thread_num is not None:
self.n_threads = int(thread_num)
else:
self.n_threads = 2
    def completion_to_prompt(self, completion):
return f"<|system|>\n</s>\n<|user|>\n{completion}</s>\n<|assistant|>\n"
    def messages_to_prompt(self, messages):
prompt = ""
for message in messages:
if message.role == "system":
prompt += f"<|system|>\n{message.content}</s>\n"
elif message.role == "user":
prompt += f"<|user|>\n{message.content}</s>\n"
elif message.role == "assistant":
prompt += f"<|assistant|>\n{message.content}</s>\n"
# ensure we start with a system prompt, insert blank if needed
if not prompt.startswith("<|system|>\n"):
prompt = "<|system|>\n</s>\n" + prompt
# add final assistant prompt
prompt = prompt + "<|assistant|>\n"
return prompt
def test_bigdl_llm(self):
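        # Smoke test: load the model through BigdlLLM on CPU and run one short completion.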
llm = BigdlLLM(
model_name=self.llama_model_path,
tokenizer_name=self.llama_model_path,
context_window=512,
max_new_tokens=32,
model_kwargs={},
generate_kwargs={"temperature": 0.7, "do_sample": False},
messages_to_prompt=self.messages_to_prompt,
completion_to_prompt=self.completion_to_prompt,
device_map="cpu",
)
res = llm.complete("What is AI?")
        assert res is not None
if __name__ == '__main__':
pytest.main([__file__])