Llamaindex: add tokenizer_id and support chat (#10590)

* add tokenizer_id

* fix

* modify

* add from_model_id and from_model_id_low_bit

* fix typo and add comment

* fix python code style

---------

Co-authored-by: pengyb2001 <284261055@qq.com>
Zhicun 2024-04-07 13:51:34 +08:00 committed by GitHub
parent 10ee786920
commit 9d8ba64c0d
6 changed files with 226 additions and 38 deletions


@@ -4,10 +4,6 @@
This folder contains examples showcasing how to use [**LlamaIndex**](https://github.com/run-llama/llama_index) with `ipex-llm`.
> [**LlamaIndex**](https://github.com/run-llama/llama_index) is a data framework designed to improve large language models by providing tools for easier data ingestion, management, and application integration.
## Prerequisites
Ensure `ipex-llm` is installed by following the [IPEX-LLM Installation Guide](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm#install) before proceeding with the examples provided here.
## Retrieval-Augmented Generation (RAG) Example
The RAG example ([rag.py](./rag.py)) is adapted from the [Official llama index RAG example](https://docs.llamaindex.ai/en/stable/examples/low_level/oss_ingestion_retrieval.html). This example builds a pipeline to ingest data (e.g. llama2 paper in pdf format) into a vector database (e.g. PostgreSQL), and then build a retrieval pipeline from that vector database.
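For orientation, the flow that `rag.py` implements looks roughly like the sketch below. It is not the example itself: it swaps the PostgreSQL vector store for an in-memory `VectorStoreIndex`, and the model, tokenizer, and document paths are placeholders.

```python
# Simplified sketch of the rag.py flow (placeholder paths; the real example
# persists embeddings in PostgreSQL rather than an in-memory index).
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from ipex_llm.llamaindex.llms import IpexLLM

embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")
llm = IpexLLM.from_model_id(
    model_name="/path/to/llm",            # -m MODEL_PATH
    tokenizer_name="/path/to/tokenizer",  # -t TOKENIZER_PATH
    context_window=512,
    max_new_tokens=64,
)

documents = SimpleDirectoryReader(input_files=["./data/llama2.pdf"]).load_data()
index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)
query_engine = index.as_query_engine(llm=llm)
print(query_engine.query("How does Llama 2 compare to other open models?"))
```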
@@ -21,6 +17,10 @@ The RAG example ([rag.py](./rag.py)) is adapted from the [Official llama index R
pip install llama-index-readers-file llama-index-vector-stores-postgres llama-index-embeddings-huggingface
```
* **Install IPEX-LLM**
Ensure `ipex-llm` is installed by following the [IPEX-LLM Installation Guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install.html) before proceeding with the examples provided here.
* **Database Setup (using PostgreSQL)**:
* Installation:
```bash
@@ -55,7 +55,7 @@ The RAG example ([rag.py](./rag.py)) is adapted from the [Official llama index R
In the current directory, run the example with command:
```bash
python rag.py -m <path_to_model>
python rag.py -m <path_to_model> -t <path_to_tokenizer>
```
**Additional Parameters for Configuration**:
- `-m MODEL_PATH`: **Required**, path to the LLM model
@@ -65,6 +65,7 @@ python rag.py -m <path_to_model>
- `-q QUESTION`: question you want to ask
- `-d DATA`: path to source data used for retrieval (in pdf format)
- `-n N_PREDICT`: max predict tokens
- `-t TOKENIZER_PATH`: **Required**, path to the tokenizer model
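The `-m`, `-t`, and `-n` values above are parsed with `argparse` and passed straight into `IpexLLM.from_model_id`; a minimal sketch of that wiring, abridged from the rag.py change shown further down:

```python
# Minimal sketch of how rag.py wires -m / -t / -n into IpexLLM (abridged).
import argparse
from ipex_llm.llamaindex.llms import IpexLLM

parser = argparse.ArgumentParser(description="LlamaIndex IpexLLM example")
parser.add_argument('-m', '--model-path', type=str, required=True,
                    help='path to the LLM model')
parser.add_argument('-t', '--tokenizer-path', type=str, required=True,
                    help='the path to transformers tokenizer')
parser.add_argument('-n', '--n-predict', type=int, default=64,
                    help='max number of predict tokens')
args = parser.parse_args()

llm = IpexLLM.from_model_id(
    model_name=args.model_path,
    tokenizer_name=args.tokenizer_path,  # new: tokenizer can live at a different path
    max_new_tokens=args.n_predict,
)
```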
### Example Output


@@ -163,11 +163,11 @@ def messages_to_prompt(messages):
def main(args):
embed_model = HuggingFaceEmbedding(model_name=args.embedding_model_path)
# Use custom LLM in IPEX-LLM
from ipex_llm.llamaindex.llms import BigdlLLM
llm = BigdlLLM(
# Use custom LLM in BigDL
from ipex_llm.llamaindex.llms import IpexLLM
llm = IpexLLM.from_model_id(
model_name=args.model_path,
tokenizer_name=args.model_path,
tokenizer_name=args.tokenizer_path,
context_window=512,
max_new_tokens=args.n_predict,
generate_kwargs={"temperature": 0.7, "do_sample": False},
@@ -244,8 +244,10 @@ if __name__ == "__main__":
help="the password of the user in the database")
parser.add_argument('-e','--embedding-model-path',default="BAAI/bge-small-en",
help="the path to embedding model path")
parser.add_argument('-n','--n-predict', type=int, default=32,
parser.add_argument('-n','--n-predict', type=int, default=64,
help='max number of predict tokens')
parser.add_argument('-t','--tokenizer-path',type=str,required=True,
help='the path to transformers tokenizer')
args = parser.parse_args()
main(args)
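Because this change also gives `IpexLLM` `chat` and `stream_chat`, the object built in rag.py can be driven conversationally outside the RAG pipeline as well; a hedged sketch (placeholder paths, assuming a chat-tuned model):

```python
# Sketch: chat with IpexLLM directly (placeholder paths).
from llama_index.core.llms import ChatMessage
from ipex_llm.llamaindex.llms import IpexLLM

llm = IpexLLM.from_model_id(
    model_name="/path/to/llm",
    tokenizer_name="/path/to/tokenizer",
    context_window=512,
    max_new_tokens=64,
)
response = llm.chat([ChatMessage(role="user", content="What does a vector database store?")])
print(response.message.content)
```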


@@ -16,9 +16,9 @@ The RAG example ([rag.py](./rag.py)) is adapted from the [Official llama index R
```bash
pip install llama-index-readers-file llama-index-vector-stores-postgres llama-index-embeddings-huggingface
```
* **Install Bigdl LLM**
* **Install IPEX-LLM**
Follow the instructions in [GPU Install Guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install_gpu.html) to install ipex-llm.
Follow the instructions in [GPU Install Guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install.html) to install ipex-llm.
* **Database Setup (using PostgreSQL)**:
* Linux
@@ -145,7 +145,7 @@ There is no need to set further environment variables.
In the current directory, run the example with command:
```bash
python rag.py -m <path_to_model>
python rag.py -m <path_to_model> -t <path_to_tokenizer>
```
**Additional Parameters for Configuration**:
- `-m MODEL_PATH`: **Required**, path to the LLM model
@@ -155,6 +155,7 @@ python rag.py -m <path_to_model>
- `-q QUESTION`: question you want to ask
- `-d DATA`: path to source data used for retrieval (in pdf format)
- `-n N_PREDICT`: max predict tokens
- `-t TOKENIZER_PATH`: **Required**, path to the tokenizer model
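The GPU example takes the same flags; the extra knob is `device_map`, which the wrapper uses to move the loaded model onto the XPU. A sketch, assuming the Intel GPU drivers and oneAPI environment described above are set up:

```python
# Sketch: target an Intel GPU by passing device_map="xpu" (placeholder paths).
from ipex_llm.llamaindex.llms import IpexLLM

llm = IpexLLM.from_model_id(
    model_name="/path/to/llm",            # -m MODEL_PATH
    tokenizer_name="/path/to/tokenizer",  # -t TOKENIZER_PATH
    context_window=512,
    max_new_tokens=64,
    device_map="xpu",  # the constructor calls model.to(device_map) when "xpu" is in it
)
```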
### 5. Example Output


@@ -162,11 +162,11 @@ def messages_to_prompt(messages):
def main(args):
embed_model = HuggingFaceEmbedding(model_name=args.embedding_model_path)
# Use custom LLM in IPEX-LLM
from ipex_llm.llamaindex.llms import BigdlLLM
llm = BigdlLLM(
# Use custom LLM in BigDL
from ipex_llm.llamaindex.llms import IpexLLM
llm = IpexLLM.from_model_id(
model_name=args.model_path,
tokenizer_name=args.model_path,
tokenizer_name=args.tokenizer_path,
context_window=512,
max_new_tokens=args.n_predict,
generate_kwargs={"temperature": 0.7, "do_sample": False},
@@ -245,6 +245,8 @@ if __name__ == "__main__":
help="the path to embedding model path")
parser.add_argument('-n','--n-predict', type=int, default=32,
help='max number of predict tokens')
parser.add_argument('-t','--tokenizer-path',type=str,required=True,
help='the path to transformers tokenizer')
args = parser.parse_args()
main(args)
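The new streaming path pairs with chat as well; a short sketch of `stream_chat`, which surfaces the incremental deltas produced by the `TextIteratorStreamer`-based `stream_complete` (placeholder paths):

```python
# Sketch: stream a chat reply token by token (placeholder paths).
from llama_index.core.llms import ChatMessage
from ipex_llm.llamaindex.llms import IpexLLM

llm = IpexLLM.from_model_id(
    model_name="/path/to/llm",
    tokenizer_name="/path/to/tokenizer",
    context_window=512,
    max_new_tokens=64,
)
for chunk in llm.stream_chat([ChatMessage(role="user", content="Briefly, what is RAG?")]):
    print(chunk.delta, end="", flush=True)
```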


@@ -26,9 +26,9 @@ from .bigdlllm import *
from llama_index.core.base.llms.base import BaseLLM
__all__ = [
"BigdlLLM",
"IpexLLm",
]
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"BigdlLLM": BigdlLLM,
"IpexLLM": IpexLLM,
}
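The package keeps this small string-to-class registry alongside `__all__`, so callers can resolve the wrapper by name; a brief sketch of that lookup (assuming `type_to_cls_dict` is importable from the package, as the module-level definition above suggests):

```python
# Sketch: resolve the LLM wrapper class by its registered name.
from ipex_llm.llamaindex.llms import IpexLLM, type_to_cls_dict

llm_cls = type_to_cls_dict["IpexLLM"]
assert llm_cls is IpexLLM
llm = llm_cls.from_model_id(
    model_name="/path/to/llm",
    tokenizer_name="/path/to/tokenizer",
)
```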


@@ -70,6 +70,8 @@ from llama_index.core.llms.custom import CustomLLM
from llama_index.core.base.llms.generic_utils import (
messages_to_prompt as generic_messages_to_prompt,
completion_response_to_chat_response,
stream_completion_response_to_chat_response
)
from llama_index.core.prompts.base import PromptTemplate
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
@@ -85,14 +87,14 @@ DEFAULT_HUGGINGFACE_MODEL = "meta-llama/Llama-2-7b-chat-hf"
logger = logging.getLogger(__name__)
class BigdlLLM(CustomLLM):
"""Wrapper around the BigDL-LLM
class IpexLLM(CustomLLM):
"""Wrapper around the IPEX-LLM
Example:
.. code-block:: python
from ipex_llm.llamaindex.llms import BigdlLLM
llm = BigdlLLM(model_path="/path/to/llama/model")
from ipex_llm.llamaindex.llms import IpexLLM
llm = IpexLLM(model_path="/path/to/llama/model")
"""
model_name: str = Field(
@@ -171,6 +173,10 @@ class BigdlLLM(CustomLLM):
"`messages_to_prompt` that does so."
),
)
load_low_bit: bool = Field(
default=False,
description="The model is low_bit model or not"
)
_model: Any = PrivateAttr()
_tokenizer: Any = PrivateAttr()
@@ -198,9 +204,10 @@ class BigdlLLM(CustomLLM):
completion_to_prompt: Optional[Callable[[str], str]]=None,
pydantic_program_mode: PydanticProgramMode=PydanticProgramMode.DEFAULT,
output_parser: Optional[BaseOutputParser] = None,
load_low_bit: bool = False
) -> None:
"""
Construct BigdlLLM.
Construct Ipex-LLM.
Args:
@@ -229,6 +236,7 @@ class BigdlLLM(CustomLLM):
completion_to_prompt: Function to convert messages to prompt.
pydantic_program_mode: DEFAULT.
output_parser: BaseOutputParser.
load_low_bit: Use low_bit checkpoint.
Returns:
None.
@@ -238,15 +246,25 @@ class BigdlLLM(CustomLLM):
if model:
self._model = model
else:
try:
self._model = AutoModelForCausalLM.from_pretrained(
model_name, load_in_4bit=True, use_cache=True,
trust_remote_code=True, **model_kwargs
)
except:
from ipex_llm.transformers import AutoModel
self._model = AutoModel.from_pretrained(model_name,
load_in_4bit=True, **model_kwargs)
if not load_low_bit:
try:
self._model = AutoModelForCausalLM.from_pretrained(
model_name, load_in_4bit=True, use_cache=True,
trust_remote_code=True, **model_kwargs
)
except:
from ipex_llm.transformers import AutoModel
self._model = AutoModel.from_pretrained(model_name,
load_in_4bit=True,
**model_kwargs)
else:
try:
self._model = AutoModelForCausalLM.load_low_bit(
model_name, use_cache=True,
trust_remote_code=True, **model_kwargs)
except:
from ipex_llm.transformers import AutoModel
self._model = AutoModel.load_low_bit(model_name, **model_kwargs)
if 'xpu' in device_map:
self._model = self._model.to(device_map)
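The `load_low_bit` branch expects `model_name` to point at a checkpoint that was previously saved in ipex-llm's low-bit format rather than a vanilla HuggingFace model. A hedged sketch of producing such a checkpoint with the `save_low_bit` API (paths are placeholders):

```python
# Sketch: create a low-bit checkpoint that the load_low_bit=True path can consume.
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "/path/to/llm", load_in_4bit=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("/path/to/llm", trust_remote_code=True)

model.save_low_bit("/path/to/low-bit-checkpoint")         # low-bit weights
tokenizer.save_pretrained("/path/to/low-bit-checkpoint")  # tokenizer alongside
```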
@@ -271,7 +289,6 @@ class BigdlLLM(CustomLLM):
if tokenizer:
self._tokenizer = tokenizer
else:
print(f"load tokenizer: {tokenizer_name}")
try:
self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)
except:
@@ -327,6 +344,157 @@ class BigdlLLM(CustomLLM):
output_parser=output_parser,
)
@classmethod
def from_model_id(
cls,
context_window: int = DEFAULT_CONTEXT_WINDOW,
max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
query_wrapper_prompt: Union[str, PromptTemplate]="{query_str}",
tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL,
model_name: str = DEFAULT_HUGGINGFACE_MODEL,
model: Optional[Any] = None,
tokenizer: Optional[Any] = None,
device_map: Optional[str] = "auto",
stopping_ids: Optional[List[int]] = None,
tokenizer_kwargs: Optional[dict] = None,
tokenizer_outputs_to_remove: Optional[list] = None,
model_kwargs: Optional[dict] = None,
generate_kwargs: Optional[dict] = None,
is_chat_model: Optional[bool] = False,
callback_manager: Optional[CallbackManager] = None,
system_prompt: str = "",
messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]]=None,
completion_to_prompt: Optional[Callable[[str], str]]=None,
pydantic_program_mode: PydanticProgramMode=PydanticProgramMode.DEFAULT,
output_parser: Optional[BaseOutputParser] = None,
):
"""
Construct IPEX-LLM from HuggingFace Model.
Args:
context_window: The maximum number of tokens available for input.
max_new_tokens: The maximum number of tokens to generate.
query_wrapper_prompt: The query wrapper prompt, containing the query placeholder.
Should contain a `{query_str}` placeholder.
tokenizer_name: The name of the tokenizer to use from HuggingFace.
Unused if `tokenizer` is passed in directly.
model_name: The model name to use from HuggingFace.
Unused if `model` is passed in directly.
model: The HuggingFace model.
tokenizer: The tokenizer.
device_map: The device_map to use. Defaults to 'auto'.
stopping_ids: The stopping ids to use.
Generation stops when these token IDs are predicted.
tokenizer_kwargs: The kwargs to pass to the tokenizer.
tokenizer_outputs_to_remove: The outputs to remove from the tokenizer.
Sometimes huggingface tokenizers return extra inputs that cause errors.
model_kwargs: The kwargs to pass to the model during initialization.
generate_kwargs: The kwargs to pass to the model during generation.
is_chat_model: Whether the model is `chat`
callback_manager: Callback manager.
system_prompt: The system prompt, containing any extra instructions or context.
messages_to_prompt: Function to convert messages to prompt.
completion_to_prompt: Function to convert messages to prompt.
pydantic_program_mode: DEFAULT.
output_parser: BaseOutputParser.
Returns:
Ipex-LLM instance.
"""
return cls(
context_window=context_window,
max_new_tokens=max_new_tokens,
query_wrapper_prompt=query_wrapper_prompt,
tokenizer_name=tokenizer_name,
model_name=model_name,
device_map=device_map,
tokenizer_kwargs=tokenizer_kwargs,
model_kwargs=model_kwargs,
generate_kwargs=generate_kwargs,
is_chat_model=is_chat_model,
callback_manager=callback_manager,
system_prompt=system_prompt,
completion_to_prompt=completion_to_prompt,
messages_to_prompt=messages_to_prompt,
)
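A hedged usage sketch of `from_model_id` for plain completions; the model id and prompt are placeholders, and the sampling settings mirror the rag.py examples:

```python
# Sketch: build IpexLLM from a model id / local path and run one completion.
from ipex_llm.llamaindex.llms import IpexLLM

llm = IpexLLM.from_model_id(
    model_name="meta-llama/Llama-2-7b-chat-hf",     # or a local path
    tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
    context_window=512,
    max_new_tokens=64,
    generate_kwargs={"temperature": 0.7, "do_sample": False},
)
print(llm.complete("What is a retrieval pipeline?").text)
```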
@classmethod
def from_model_id_low_bit(
cls,
context_window: int = DEFAULT_CONTEXT_WINDOW,
max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
query_wrapper_prompt: Union[str, PromptTemplate]="{query_str}",
tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL,
model_name: str = DEFAULT_HUGGINGFACE_MODEL,
model: Optional[Any] = None,
tokenizer: Optional[Any] = None,
device_map: Optional[str] = "auto",
stopping_ids: Optional[List[int]] = None,
tokenizer_kwargs: Optional[dict] = None,
tokenizer_outputs_to_remove: Optional[list] = None,
model_kwargs: Optional[dict] = None,
generate_kwargs: Optional[dict] = None,
is_chat_model: Optional[bool] = False,
callback_manager: Optional[CallbackManager] = None,
system_prompt: str = "",
messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]]=None,
completion_to_prompt: Optional[Callable[[str], str]]=None,
pydantic_program_mode: PydanticProgramMode=PydanticProgramMode.DEFAULT,
output_parser: Optional[BaseOutputParser] = None,
):
"""
Construct IPEX-LLM from HuggingFace Model low-bit checkpoint.
Args:
context_window: The maximum number of tokens available for input.
max_new_tokens: The maximum number of tokens to generate.
query_wrapper_prompt: The query wrapper prompt, containing the query placeholder.
Should contain a `{query_str}` placeholder.
tokenizer_name: The name of the tokenizer to use from HuggingFace.
Unused if `tokenizer` is passed in directly.
model_name: The model name to use from HuggingFace.
Unused if `model` is passed in directly.
model: The HuggingFace model.
tokenizer: The tokenizer.
device_map: The device_map to use. Defaults to 'auto'.
stopping_ids: The stopping ids to use.
Generation stops when these token IDs are predicted.
tokenizer_kwargs: The kwargs to pass to the tokenizer.
tokenizer_outputs_to_remove: The outputs to remove from the tokenizer.
Sometimes huggingface tokenizers return extra inputs that cause errors.
model_kwargs: The kwargs to pass to the model during initialization.
generate_kwargs: The kwargs to pass to the model during generation.
is_chat_model: Whether the model is `chat`
callback_manager: Callback manager.
system_prompt: The system prompt, containing any extra instructions or context.
messages_to_prompt: Function to convert messages to prompt.
completion_to_prompt: Function to convert messages to prompt.
pydantic_program_mode: DEFAULT.
output_parser: BaseOutputParser.
Returns:
Ipex-LLM instance.
"""
return cls(
context_window=context_window,
max_new_tokens=max_new_tokens,
query_wrapper_prompt=query_wrapper_prompt,
tokenizer_name=tokenizer_name,
model_name=model_name,
device_map=device_map,
tokenizer_kwargs=tokenizer_kwargs,
model_kwargs=model_kwargs,
generate_kwargs=generate_kwargs,
is_chat_model=is_chat_model,
callback_manager=callback_manager,
system_prompt=system_prompt,
completion_to_prompt=completion_to_prompt,
messages_to_prompt=messages_to_prompt,
load_low_bit=True,
)
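And the low-bit counterpart, loading a checkpoint written by `save_low_bit` (directory path is a placeholder):

```python
# Sketch: load a previously saved low-bit checkpoint.
from ipex_llm.llamaindex.llms import IpexLLM

llm = IpexLLM.from_model_id_low_bit(
    model_name="/path/to/low-bit-checkpoint",      # directory written by save_low_bit
    tokenizer_name="/path/to/low-bit-checkpoint",  # tokenizer saved alongside
    context_window=512,
    max_new_tokens=64,
)
print(llm.complete("Hello!").text)
```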
@classmethod
def class_name(cls) -> str:
"""
@@ -338,7 +506,7 @@ class BigdlLLM(CustomLLM):
Returns:
Str of class name.
"""
return "BigDL_LLM"
return "IpexLLM"
@property
def metadata(self) -> LLMMetadata:
@@ -431,7 +599,7 @@ class BigdlLLM(CustomLLM):
Returns:
CompletionReponse after generation.
"""
from transformers import TextStreamer
from transformers import TextIteratorStreamer
full_prompt = prompt
if not formatted:
if self.query_wrapper_prompt:
@@ -446,9 +614,9 @@ class BigdlLLM(CustomLLM):
if key in input_ids:
input_ids.pop(key, None)
streamer = TextStreamer(self._tokenizer, skip_prompt=True, skip_special_tokens=True)
streamer = TextIteratorStreamer(self._tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
input_ids,
input_ids=input_ids,
streamer=streamer,
max_new_tokens=self.max_new_tokens,
stopping_criteria=self._stopping_criteria,
@@ -465,3 +633,17 @@ class BigdlLLM(CustomLLM):
yield CompletionResponse(text=text, delta=x)
return gen()
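The move from `TextStreamer` to `TextIteratorStreamer` is what lets `stream_complete` (and, below, `stream_chat`) hand text back to the caller instead of printing it: generation runs in a background thread while the streamer is consumed as an iterator. A standalone sketch of that pattern with plain transformers (placeholder model path):

```python
# Sketch: the thread + TextIteratorStreamer pattern behind streaming generation.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("/path/to/llm")
model = AutoModelForCausalLM.from_pretrained("/path/to/llm")

inputs = tokenizer("Tell me about LlamaIndex.", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so run it in a thread and iterate the streamer here.
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64))
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()
```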
@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
prompt = self.messages_to_prompt(messages)
completion_response = self.complete(prompt, formatted=True, **kwargs)
return completion_response_to_chat_response(completion_response)
@llm_chat_callback()
def stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
prompt = self.messages_to_prompt(messages)
completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
return stream_completion_response_to_chat_response(completion_response)
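Both `chat` and `stream_chat` funnel through `self.messages_to_prompt`, so chat quality hinges on a template that matches the model. A hedged sketch of one common shape; the tag format here is illustrative, not necessarily the template the rag.py examples use:

```python
# Sketch: one possible messages_to_prompt; the real template depends on the model.
from llama_index.core.llms import MessageRole

def messages_to_prompt(messages):
    prompt = ""
    for message in messages:
        if message.role == MessageRole.SYSTEM:
            prompt += f"<|system|>\n{message.content}\n"
        elif message.role == MessageRole.USER:
            prompt += f"<|user|>\n{message.content}\n"
        elif message.role == MessageRole.ASSISTANT:
            prompt += f"<|assistant|>\n{message.content}\n"
    if not prompt.startswith("<|system|>\n"):
        prompt = "<|system|>\n\n" + prompt   # ensure a (possibly empty) system slot
    return prompt + "<|assistant|>\n"        # cue the model to answer as assistant
```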