diff --git a/python/llm/example/CPU/LlamaIndex/README.md b/python/llm/example/CPU/LlamaIndex/README.md new file mode 100644 index 00000000..f12e223c --- /dev/null +++ b/python/llm/example/CPU/LlamaIndex/README.md @@ -0,0 +1,60 @@ +# LlamaIndex Examples + +The examples here show how to use LlamaIndex with `bigdl-llm`. +The RAG example is modified from the [demo](https://docs.llamaindex.ai/en/stable/examples/low_level/oss_ingestion_retrieval.html). + +## Install bigdl-llm +Follow the instructions in [Install](https://github.com/intel-analytics/BigDL/tree/main/python/llm#install). + +## Install Required Dependencies for the LlamaIndex Examples + +### Install LlamaIndex Packages +```bash +pip install llama-index-readers-file +pip install llama-index-vector-stores-postgres +pip install llama-index-embeddings-huggingface +``` + +### Install Postgres +> Note: There are plenty of open-source databases you can use. Here we provide an example using Postgres. +* Download and install postgres by running the commands below. + ```bash + sudo apt-get install postgresql-client + sudo apt-get install postgresql + ``` +* Initialize postgres. + ```bash + sudo su - postgres + psql + ``` + Running these commands opens the postgres console. There, create a role as follows, replacing `<user>` and `<password>` with your own username and password: + ```sql + CREATE ROLE <user> WITH LOGIN PASSWORD '<password>'; + ALTER ROLE <user> SUPERUSER; + ``` +* Install pgvector according to its [project page](https://github.com/pgvector/pgvector). If you encounter problems during installation, please refer to the [installation notes](https://github.com/pgvector/pgvector#installation-notes), which may be helpful. +* Download the example data (the Llama 2 paper, used as the retrieval corpus). + ```bash + mkdir data + wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf" + ``` + + +## Run the examples + +### Retrieval-Augmented Generation +```bash +python rag.py -m MODEL_PATH -e EMBEDDING_MODEL_PATH -u USERNAME -p PASSWORD -q QUESTION -d DATA +``` +Arguments info: +- `-m MODEL_PATH`: **required**, path to the Llama model +- `-e EMBEDDING_MODEL_PATH`: path to the embedding model +- `-u USERNAME`: username for the postgres database +- `-p PASSWORD`: password for the postgres database +- `-q QUESTION`: question you want to ask +- `-d DATA`: path to the data used during retrieval + +Here is the sample output when using Llama-2-7b-chat-hf as the generation model, asking "How does Llama 2 perform compared to other open-source models?" with llama2.pdf as the retrieval data. +``` +Llama 2 performs better than most open-source models on the benchmarks we tested. Specifically, it outperforms all open-source models on MMLU and BBH, and is close to GPT-3.5 on these benchmarks. Additionally, Llama 2 is on par or better than PaLM-2-L on almost all benchmarks. The only exception is the coding benchmarks, where Llama 2 lags significantly behind GPT-4 and PaLM-2-L. Overall, Llama 2 demonstrates strong performance on a wide range of natural language processing tasks. +``` diff --git a/python/llm/example/CPU/LlamaIndex/rag.py b/python/llm/example/CPU/LlamaIndex/rag.py new file mode 100644 index 00000000..cf11b5ad --- /dev/null +++ b/python/llm/example/CPU/LlamaIndex/rag.py @@ -0,0 +1,247 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from llama_index.embeddings.huggingface import HuggingFaceEmbedding +from sqlalchemy import make_url +from llama_index.vector_stores.postgres import PGVectorStore +# from llama_index.llms.llama_cpp import LlamaCPP +import psycopg2 +from pathlib import Path +from llama_index.readers.file import PyMuPDFReader +from llama_index.core.schema import NodeWithScore +from typing import Optional +from llama_index.core.query_engine import RetrieverQueryEngine +from llama_index.core import QueryBundle +from llama_index.core.retrievers import BaseRetriever +from typing import Any, List +from llama_index.core.node_parser import SentenceSplitter +from llama_index.core.vector_stores import VectorStoreQuery +import argparse + +def load_vector_database(username, password): + db_name = "example_db" + host = "localhost" + password = password + port = "5432" + user = username + # conn = psycopg2.connect(connection_string) + conn = psycopg2.connect( + dbname="postgres", + host=host, + password=password, + port=port, + user=user, + ) + conn.autocommit = True + + with conn.cursor() as c: + c.execute(f"DROP DATABASE IF EXISTS {db_name}") + c.execute(f"CREATE DATABASE {db_name}") + + vector_store = PGVectorStore.from_params( + database=db_name, + host=host, + password=password, + port=port, + user=user, + table_name="llama2_paper", + embed_dim=384, # openai embedding dimension + ) + return vector_store + + +def load_data(data_path): + loader = PyMuPDFReader() + documents = loader.load(file_path=data_path) + + + text_parser = SentenceSplitter( + chunk_size=1024, + # separator=" ", + ) + text_chunks = [] + # maintain relationship with source doc index, to help inject doc metadata in (3) + doc_idxs = [] + for doc_idx, doc in enumerate(documents): + cur_text_chunks = text_parser.split_text(doc.text) + text_chunks.extend(cur_text_chunks) + doc_idxs.extend([doc_idx] * len(cur_text_chunks)) + + from llama_index.core.schema import TextNode + nodes = [] + for idx, text_chunk in enumerate(text_chunks): + node = TextNode( + text=text_chunk, + ) + src_doc = documents[doc_idxs[idx]] + node.metadata = src_doc.metadata + nodes.append(node) + return nodes + + + + + + +class VectorDBRetriever(BaseRetriever): + """Retriever over a postgres vector store.""" + + def __init__( + self, + vector_store: PGVectorStore, + embed_model: Any, + query_mode: str = "default", + similarity_top_k: int = 2, + ) -> None: + """Init params.""" + self._vector_store = vector_store + self._embed_model = embed_model + self._query_mode = query_mode + self._similarity_top_k = similarity_top_k + super().__init__() + + def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + """Retrieve.""" + query_embedding = self._embed_model.get_query_embedding( + query_bundle.query_str + ) + vector_store_query = VectorStoreQuery( + query_embedding=query_embedding, + similarity_top_k=self._similarity_top_k, + mode=self._query_mode, + ) + query_result = self._vector_store.query(vector_store_query) + + nodes_with_scores = [] + for index, node in enumerate(query_result.nodes): + score: Optional[float] = None + if query_result.similarities is not 
None: + score = query_result.similarities[index] + nodes_with_scores.append(NodeWithScore(node=node, score=score)) + + return nodes_with_scores + +def completion_to_prompt(completion): + return f"<|system|>\n\n<|user|>\n{completion}\n<|assistant|>\n" + + +# Transform a list of chat messages into zephyr-specific input +def messages_to_prompt(messages): + prompt = "" + for message in messages: + if message.role == "system": + prompt += f"<|system|>\n{message.content}\n" + elif message.role == "user": + prompt += f"<|user|>\n{message.content}\n" + elif message.role == "assistant": + prompt += f"<|assistant|>\n{message.content}\n" + + # ensure we start with a system prompt, insert blank if needed + if not prompt.startswith("<|system|>\n"): + prompt = "<|system|>\n\n" + prompt + + # add final assistant prompt + prompt = prompt + "<|assistant|>\n" + + return prompt + +def main(args): + embed_model = HuggingFaceEmbedding(model_name=args.embedding_model_path) + + # Use custom LLM in BigDL + from bigdl.llm.llamaindex.llms import BigdlLLM + llm = BigdlLLM( + model_name=args.model_path, + tokenizer_name=args.model_path, + context_window=512, + max_new_tokens=32, + generate_kwargs={"temperature": 0.7, "do_sample": False}, + model_kwargs={}, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + device_map="cpu", + ) + + vector_store = load_vector_database(username=args.user, password=args.password) + nodes = load_data(data_path=args.data) + for node in nodes: + node_embedding = embed_model.get_text_embedding( + node.get_content(metadata_mode="all") + ) + node.embedding = node_embedding + + vector_store.add(nodes) + + # query_str = "Can you tell me about the key concepts for safety finetuning" + query_str = "Explain about the training data for Llama 2" + query_embedding = embed_model.get_query_embedding(query_str) + # construct vector store query + + + query_mode = "default" + # query_mode = "sparse" + # query_mode = "hybrid" + + vector_store_query = VectorStoreQuery( + query_embedding=query_embedding, similarity_top_k=2, mode=query_mode + ) + # returns a VectorStoreQueryResult + query_result = vector_store.query(vector_store_query) + # print("Retrieval Results: ") + # print(query_result.nodes[0].get_content()) + + + + nodes_with_scores = [] + for index, node in enumerate(query_result.nodes): + score: Optional[float] = None + if query_result.similarities is not None: + score = query_result.similarities[index] + nodes_with_scores.append(NodeWithScore(node=node, score=score)) + + retriever = VectorDBRetriever( + vector_store, embed_model, query_mode="default", similarity_top_k=1 + ) + + + query_engine = RetrieverQueryEngine.from_args(retriever, llm=llm) + + # query_str = "How does Llama 2 perform compared to other open-source models?" 
+ query_str = args.question + response = query_engine.query(query_str) + + + print("------------RESPONSE GENERATION---------------------") + print(str(response)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='LlamaIndex BigdlLLM Example') + parser.add_argument('-m', '--model-path', type=str, required=True, + help='the path to the transformers model') + parser.add_argument('-q', '--question', type=str, default='How does Llama 2 perform compared to other open-source models?', + help='the question you want to ask') + parser.add_argument('-d', '--data', type=str, default='./data/llama2.pdf', + help="the path to the data used during retrieval") + parser.add_argument('-u', '--user', type=str, required=True, + help="the username of the postgres database") + parser.add_argument('-p', '--password', type=str, required=True, + help="the password of the user in the postgres database") + parser.add_argument('-e', '--embedding-model-path', default="BAAI/bge-small-en", + help="the path to the embedding model") + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/python/llm/src/bigdl/llm/llamaindex/__init__.py b/python/llm/src/bigdl/llm/llamaindex/__init__.py new file mode 100644 index 00000000..dbdafd2a --- /dev/null +++ b/python/llm/src/bigdl/llm/llamaindex/__init__.py @@ -0,0 +1,20 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This makes sure Python is aware there is more than one sub-package within bigdl, +# physically located elsewhere. +# Otherwise there would be a module-not-found error in a non-pip setting, as Python would +# only search the first bigdl package and end up finding only one sub-package. diff --git a/python/llm/src/bigdl/llm/llamaindex/llms/__init__.py b/python/llm/src/bigdl/llm/llamaindex/llms/__init__.py new file mode 100644 index 00000000..54fc4aa5 --- /dev/null +++ b/python/llm/src/bigdl/llm/llamaindex/llms/__init__.py @@ -0,0 +1,34 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This makes sure Python is aware there is more than one sub-package within bigdl, +# physically located elsewhere. +# Otherwise there would be a module-not-found error in a non-pip setting, as Python would +# only search the first bigdl package and end up finding only one sub-package.
+ +"""Wrappers on top of large language models APIs.""" +from typing import Dict, Type + +from .bigdlllm import * +from llama_index.core.base.llms.base import BaseLLM + +__all__ = [ + "BigdlLLM", +] + +type_to_cls_dict: Dict[str, Type[BaseLLM]] = { + "BigdlLLM": BigdlLLM, +} diff --git a/python/llm/src/bigdl/llm/llamaindex/llms/bigdlllm.py b/python/llm/src/bigdl/llm/llamaindex/llms/bigdlllm.py new file mode 100644 index 00000000..81a5ce26 --- /dev/null +++ b/python/llm/src/bigdl/llm/llamaindex/llms/bigdlllm.py @@ -0,0 +1,449 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# The file is modified from +# https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/base/llms/base.py + +# The MIT License + +# Copyright (c) Harrison Chase + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import logging +from threading import Thread +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +import torch +from huggingface_hub import AsyncInferenceClient, InferenceClient, model_info +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) +from llama_index.core.bridge.pydantic import Field, PrivateAttr +from llama_index.core.callbacks import CallbackManager +from llama_index.core.constants import ( + DEFAULT_CONTEXT_WINDOW, + DEFAULT_NUM_OUTPUTS, +) +from llama_index.core.llms.callbacks import ( + llm_chat_callback, + llm_completion_callback, +) +from llama_index.core.llms.custom import CustomLLM + +from llama_index.core.base.llms.generic_utils import ( + messages_to_prompt as generic_messages_to_prompt, +) +from llama_index.core.prompts.base import PromptTemplate +from llama_index.core.types import BaseOutputParser, PydanticProgramMode +from transformers import ( + StoppingCriteria, + StoppingCriteriaList, +) +from transformers import AutoTokenizer, LlamaTokenizer + + +DEFAULT_HUGGINGFACE_MODEL = "meta-llama/Llama-2-7b-chat-hf" + +logger = logging.getLogger(__name__) + + +class BigdlLLM(CustomLLM): + """Wrapper around the BigDL-LLM + + Example: + .. code-block:: python + + from bigdl.llm.llamaindex.llms import BigdlLLM + llm = BigdlLLM(model_path="/path/to/llama/model") + """ + + model_name: str = Field( + default=DEFAULT_HUGGINGFACE_MODEL, + description=( + "The model name to use from HuggingFace. " + "Unused if `model` is passed in directly." + ), + ) + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description="The maximum number of tokens available for input.", + gt=0, + ) + max_new_tokens: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description="The maximum number of tokens to generate.", + gt=0, + ) + system_prompt: str = Field( + default="", + description=( + "The system prompt, containing any extra instructions or context. " + "The model card on HuggingFace should specify if this is needed." + ), + ) + query_wrapper_prompt: PromptTemplate = Field( + default=PromptTemplate("{query_str}"), + description=( + "The query wrapper prompt, containing the query placeholder. " + "The model card on HuggingFace should specify if this is needed. " + "Should contain a `{query_str}` placeholder." + ), + ) + tokenizer_name: str = Field( + default=DEFAULT_HUGGINGFACE_MODEL, + description=( + "The name of the tokenizer to use from HuggingFace. " + "Unused if `tokenizer` is passed in directly." + ), + ) + device_map: str = Field( + default="auto", description="The device_map to use. Defaults to 'auto'." + ) + stopping_ids: List[int] = Field( + default_factory=list, + description=( + "The stopping ids to use. " + "Generation stops when these token IDs are predicted." + ), + ) + tokenizer_outputs_to_remove: list = Field( + default_factory=list, + description=( + "The outputs to remove from the tokenizer. " + "Sometimes huggingface tokenizers return extra inputs that cause errors." + ), + ) + tokenizer_kwargs: dict = Field( + default_factory=dict, description="The kwargs to pass to the tokenizer." 
+ ) + model_kwargs: dict = Field( + default_factory=dict, + description="The kwargs to pass to the model during initialization.", + ) + generate_kwargs: dict = Field( + default_factory=dict, + description="The kwargs to pass to the model during generation.", + ) + is_chat_model: bool = Field( + default=False, + description=( + LLMMetadata.__fields__["is_chat_model"].field_info.description + + " Be sure to verify that you either pass an appropriate tokenizer " + "that can convert prompts to properly formatted chat messages or a " + "`messages_to_prompt` that does so." + ), + ) + + _model: Any = PrivateAttr() + _tokenizer: Any = PrivateAttr() + _stopping_criteria: Any = PrivateAttr() + + def __init__( + self, + context_window: int = DEFAULT_CONTEXT_WINDOW, + max_new_tokens: int = DEFAULT_NUM_OUTPUTS, + query_wrapper_prompt: Union[str, PromptTemplate]="{query_str}", + tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL, + model_name: str = DEFAULT_HUGGINGFACE_MODEL, + model: Optional[Any] = None, + tokenizer: Optional[Any] = None, + device_map: Optional[str] = "auto", + stopping_ids: Optional[List[int]] = None, + tokenizer_kwargs: Optional[dict] = None, + tokenizer_outputs_to_remove: Optional[list] = None, + model_kwargs: Optional[dict] = None, + generate_kwargs: Optional[dict] = None, + is_chat_model: Optional[bool] = False, + callback_manager: Optional[CallbackManager] = None, + system_prompt: str = "", + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]]=None, + completion_to_prompt: Optional[Callable[[str], str]]=None, + pydantic_program_mode: PydanticProgramMode=PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, + ) -> None: + """ + Construct BigdlLLM. + + Args: + + context_window: The maximum number of tokens available for input. + max_new_tokens: The maximum number of tokens to generate. + query_wrapper_prompt: The query wrapper prompt, containing the query placeholder. + Should contain a `{query_str}` placeholder. + tokenizer_name: The name of the tokenizer to use from HuggingFace. + Unused if `tokenizer` is passed in directly. + model_name: The model name to use from HuggingFace. + Unused if `model` is passed in directly. + model: The HuggingFace model. + tokenizer: The tokenizer. + device_map: The device_map to use. Defaults to 'auto'. + stopping_ids: The stopping ids to use. + Generation stops when these token IDs are predicted. + tokenizer_kwargs: The kwargs to pass to the tokenizer. + tokenizer_outputs_to_remove: The outputs to remove from the tokenizer. + Sometimes huggingface tokenizers return extra inputs that cause errors. + model_kwargs: The kwargs to pass to the model during initialization. + generate_kwargs: The kwargs to pass to the model during generation. + is_chat_model: Whether the model is `chat` + callback_manager: Callback manager. + system_prompt: The system prompt, containing any extra instructions or context. + messages_to_prompt: Function to convert messages to prompt. + completion_to_prompt: Function to convert messages to prompt. + pydantic_program_mode: DEFAULT. + output_parser: BaseOutputParser. + + Returns: + None. 
+ """ + model_kwargs = model_kwargs or {} + from bigdl.llm.transformers import AutoModelForCausalLM + self._model = model or AutoModelForCausalLM.from_pretrained( + model_name, load_in_4bit=True, **model_kwargs + ) + + # check context_window + config_dict = self._model.config.to_dict() + model_context_window = int( + config_dict.get("max_position_embeddings", context_window) + ) + if model_context_window and model_context_window < context_window: + logger.warning( + f"Supplied context_window {context_window} is greater " + f"than the model's max input size {model_context_window}. " + "Disable this warning by setting a lower context_window." + ) + context_window = model_context_window + + tokenizer_kwargs = tokenizer_kwargs or {} + if "max_length" not in tokenizer_kwargs: + tokenizer_kwargs["max_length"] = context_window + + self._tokenizer = tokenizer or AutoTokenizer.from_pretrained( + tokenizer_name, **tokenizer_kwargs + ) + + if tokenizer_name != model_name: + logger.warning( + f"The model `{model_name}` and tokenizer `{tokenizer_name}` " + f"are different, please ensure that they are compatible." + ) + + # setup stopping criteria + stopping_ids_list = stopping_ids or [] + + class StopOnTokens(StoppingCriteria): + def __call__( + self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + **kwargs: Any, + ) -> bool: + for stop_id in stopping_ids_list: + if input_ids[0][-1] == stop_id: + return True + return False + + self._stopping_criteria = StoppingCriteriaList([StopOnTokens()]) + + if isinstance(query_wrapper_prompt, str): + query_wrapper_prompt = PromptTemplate(query_wrapper_prompt) + + messages_to_prompt = messages_to_prompt or self._tokenizer_messages_to_prompt + + super().__init__( + context_window=context_window, + max_new_tokens=max_new_tokens, + query_wrapper_prompt=query_wrapper_prompt, + tokenizer_name=tokenizer_name, + model_name=model_name, + device_map=device_map, + stopping_ids=stopping_ids or [], + tokenizer_kwargs=tokenizer_kwargs or {}, + tokenizer_outputs_to_remove=tokenizer_outputs_to_remove or [], + model_kwargs=model_kwargs or {}, + generate_kwargs=generate_kwargs or {}, + is_chat_model=is_chat_model, + callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, + ) + + @classmethod + def class_name(cls) -> str: + """ + Get class name. + + Args: + + + Returns: + Str of class name. + """ + return "BigDL_LLM" + + @property + def metadata(self) -> LLMMetadata: + """ + Get meta data. + + Args: + + Returns: + LLMmetadata contains context_window, + num_output, model_name, is_chat_model. + """ + return LLMMetadata( + context_window=self.context_window, + num_output=self.max_new_tokens, + model_name=self.model_name, + is_chat_model=self.is_chat_model, + ) + + def _tokenizer_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str: + """ + Use the tokenizer to convert messages to prompt. Fallback to generic. + + Args: + messages: Sequence of ChatMessage. + + Returns: + Str of response. 
+ """ + if hasattr(self._tokenizer, "apply_chat_template"): + messages_dict = [ + {"role": message.role.value, "content": message.content} + for message in messages + ] + tokens = self._tokenizer.apply_chat_template(messages_dict) + return self._tokenizer.decode(tokens) + + return generic_messages_to_prompt(messages) + + @llm_completion_callback() + def complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + """ + Complete by LLM. + + Args: + prompt: Prompt for completion. + formatted: Whether the prompt is formatted by wrapper. + kwargs: Other kwargs for complete. + + Returns: + CompletionReponse after generation. + """ + full_prompt = prompt + if not formatted: + if self.query_wrapper_prompt: + full_prompt = self.query_wrapper_prompt.format(query_str=prompt) + if self.system_prompt: + full_prompt = f"{self.system_prompt} {full_prompt}" + input_ids = self._tokenizer(full_prompt, return_tensors="pt") + input_ids = input_ids.to(self._model.device) + # remove keys from the tokenizer if needed, to avoid HF errors + for key in self.tokenizer_outputs_to_remove: + if key in input_ids: + input_ids.pop(key, None) + tokens = self._model.generate( + **input_ids, + max_new_tokens=self.max_new_tokens, + stopping_criteria=self._stopping_criteria, + **self.generate_kwargs, + ) + completion_tokens = tokens[0][input_ids["input_ids"].size(1):] + completion = self._tokenizer.decode(completion_tokens, skip_special_tokens=True) + + return CompletionResponse(text=completion, raw={"model_output": tokens}) + + @llm_completion_callback() + def stream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseGen: + """ + Complete by LLM in stream. + + Args: + prompt: Prompt for completion. + formatted: Whether the prompt is formatted by wrapper. + kwargs: Other kwargs for complete. + + Returns: + CompletionReponse after generation. + """ + from transformers import TextStreamer + full_prompt = prompt + if not formatted: + if self.query_wrapper_prompt: + full_prompt = self.query_wrapper_prompt.format(query_str=prompt) + if self.system_prompt: + full_prompt = f"{self.system_prompt} {full_prompt}" + + input_ids = self._tokenizer.encode(full_prompt, return_tensors="pt") + input_ids = input_ids.to(self._model.device) + + for key in self.tokenizer_outputs_to_remove: + if key in input_ids: + input_ids.pop(key, None) + + streamer = TextStreamer(self._tokenizer, skip_prompt=True, skip_special_tokens=True) + generation_kwargs = dict( + input_ids, + streamer=streamer, + max_new_tokens=self.max_new_tokens, + stopping_criteria=self._stopping_criteria, + **self.generate_kwargs, + ) + thread = Thread(target=self._model.generate, kwargs=generation_kwargs) + thread.start() + + # create generator based off of streamer + def gen() -> CompletionResponseGen: + text = "" + for x in streamer: + text += x + yield CompletionResponse(text=text, delta=x) + + return gen() diff --git a/python/llm/test/llamaindex/test_llamaindex.py b/python/llm/test/llamaindex/test_llamaindex.py new file mode 100644 index 00000000..6df7293e --- /dev/null +++ b/python/llm/test/llamaindex/test_llamaindex.py @@ -0,0 +1,88 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from unittest import TestCase +import os +from bigdl.llm.llamaindex.llms import BigdlLLM + +class Test_LlamaIndex_Transformers_API(TestCase): + def setUp(self): + self.auto_model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH') + self.auto_causal_model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH') + self.llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH') + self.bloom_model_path = os.environ.get('BLOOM_ORIGIN_PATH') + thread_num = os.environ.get('THREAD_NUM') + if thread_num is not None: + self.n_threads = int(thread_num) + else: + self.n_threads = 2 + + @staticmethod + def completion_to_prompt(completion): + return f"<|system|>\n\n<|user|>\n{completion}\n<|assistant|>\n" + + @staticmethod + def messages_to_prompt(messages): + prompt = "" + for message in messages: + if message.role == "system": + prompt += f"<|system|>\n{message.content}\n" + elif message.role == "user": + prompt += f"<|user|>\n{message.content}\n" + elif message.role == "assistant": + prompt += f"<|assistant|>\n{message.content}\n" + + # ensure we start with a system prompt, insert blank if needed + if not prompt.startswith("<|system|>\n"): + prompt = "<|system|>\n\n" + prompt + + # add final assistant prompt + prompt = prompt + "<|assistant|>\n" + return prompt + + def test_bigdl_llm(self): + llm = BigdlLLM( + model_name=self.llama_model_path, + tokenizer_name=self.llama_model_path, + context_window=512, + max_new_tokens=32, + model_kwargs={}, + generate_kwargs={"temperature": 0.7, "do_sample": False}, + messages_to_prompt=self.messages_to_prompt, + completion_to_prompt=self.completion_to_prompt, + device_map="cpu", + ) + res = llm.complete("What is AI?") + assert res is not None + + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file
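
For reviewers who want to sanity-check the new `BigdlLLM` wrapper without setting up Postgres, here is a minimal usage sketch distilled from `rag.py` and the new unit test. It assumes `bigdl-llm` and the LlamaIndex packages from the README are installed; the checkpoint path is a placeholder for any local Llama-2-style model in HuggingFace format, and output quality depends on the model you point it at.

```python
# Minimal sketch: exercise the new BigdlLLM LlamaIndex wrapper directly on CPU.
# The model path is a placeholder; point it at a local HF-format checkpoint
# such as Llama-2-7b-chat-hf.
from bigdl.llm.llamaindex.llms import BigdlLLM

llm = BigdlLLM(
    model_name="/path/to/Llama-2-7b-chat-hf",      # placeholder path
    tokenizer_name="/path/to/Llama-2-7b-chat-hf",  # placeholder path
    context_window=512,
    max_new_tokens=32,
    generate_kwargs={"temperature": 0.7, "do_sample": False},
    device_map="cpu",
)

# One-shot completion, as exercised by test_bigdl_llm().
print(llm.complete("What is AI?").text)

# Streaming completion via the stream_complete() generator added in this PR.
for chunk in llm.stream_complete("What is AI?"):
    print(chunk.delta, end="", flush=True)
```

For the full RAG path, `rag.py` wires the same object into a `RetrieverQueryEngine` together with the `VectorDBRetriever` defined above.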