Llamaindex: add tokenizer_id and support chat (#10590)
* add tokenizer_id * fix * modify * add from_model_id and from_mode_id_low_bit * fix typo and add comment * fix python code style --------- Co-authored-by: pengyb2001 <284261055@qq.com>
This commit is contained in:
parent
10ee786920
commit
9d8ba64c0d
6 changed files with 226 additions and 38 deletions
|
|
@ -4,10 +4,6 @@
|
|||
This folder contains examples showcasing how to use [**LlamaIndex**](https://github.com/run-llama/llama_index) with `ipex-llm`.
|
||||
> [**LlamaIndex**](https://github.com/run-llama/llama_index) is a data framework designed to improve large language models by providing tools for easier data ingestion, management, and application integration.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Ensure `ipex-llm` is installed by following the [IPEX-LLM Installation Guide](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm#install) before proceeding with the examples provided here.
|
||||
|
||||
|
||||
## Retrieval-Augmented Generation (RAG) Example
|
||||
The RAG example ([rag.py](./rag.py)) is adapted from the [Official llama index RAG example](https://docs.llamaindex.ai/en/stable/examples/low_level/oss_ingestion_retrieval.html). This example builds a pipeline to ingest data (e.g. llama2 paper in pdf format) into a vector database (e.g. PostgreSQL), and then build a retrieval pipeline from that vector database.
|
||||
|
|
@ -21,6 +17,10 @@ The RAG example ([rag.py](./rag.py)) is adapted from the [Official llama index R
|
|||
pip install llama-index-readers-file llama-index-vector-stores-postgres llama-index-embeddings-huggingface
|
||||
```
|
||||
|
||||
* **Install IPEX-LLM**
|
||||
Ensure `ipex-llm` is installed by following the [IPEX-LLM Installation Guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install.html) before proceeding with the examples provided here.
|
||||
|
||||
|
||||
* **Database Setup (using PostgreSQL)**:
|
||||
* Installation:
|
||||
```bash
|
||||
|
|
@ -55,7 +55,7 @@ The RAG example ([rag.py](./rag.py)) is adapted from the [Official llama index R
|
|||
In the current directory, run the example with command:
|
||||
|
||||
```bash
|
||||
python rag.py -m <path_to_model>
|
||||
python rag.py -m <path_to_model> -t <path_to_tokenizer>
|
||||
```
|
||||
**Additional Parameters for Configuration**:
|
||||
- `-m MODEL_PATH`: **Required**, path to the LLM model
|
||||
|
|
@ -65,6 +65,7 @@ python rag.py -m <path_to_model>
|
|||
- `-q QUESTION`: question you want to ask
|
||||
- `-d DATA`: path to source data used for retrieval (in pdf format)
|
||||
- `-n N_PREDICT`: max predict tokens
|
||||
- `-t TOKENIZER_PATH`: **Required**, path to the tokenizer model
|
||||
|
||||
### Example Output
|
||||
|
||||
|
|
|
|||
|
|
@ -163,11 +163,11 @@ def messages_to_prompt(messages):
|
|||
def main(args):
|
||||
embed_model = HuggingFaceEmbedding(model_name=args.embedding_model_path)
|
||||
|
||||
# Use custom LLM in IPEX-LLM
|
||||
from ipex_llm.llamaindex.llms import BigdlLLM
|
||||
llm = BigdlLLM(
|
||||
# Use custom LLM in BigDL
|
||||
from ipex_llm.llamaindex.llms import IpexLLM
|
||||
llm = IpexLLM.from_model_id(
|
||||
model_name=args.model_path,
|
||||
tokenizer_name=args.model_path,
|
||||
tokenizer_name=args.tokenizer_path,
|
||||
context_window=512,
|
||||
max_new_tokens=args.n_predict,
|
||||
generate_kwargs={"temperature": 0.7, "do_sample": False},
|
||||
|
|
@ -244,8 +244,10 @@ if __name__ == "__main__":
|
|||
help="the password of the user in the database")
|
||||
parser.add_argument('-e','--embedding-model-path',default="BAAI/bge-small-en",
|
||||
help="the path to embedding model path")
|
||||
parser.add_argument('-n','--n-predict', type=int, default=32,
|
||||
parser.add_argument('-n','--n-predict', type=int, default=64,
|
||||
help='max number of predict tokens')
|
||||
parser.add_argument('-t','--tokenizer-path',type=str,required=True,
|
||||
help='the path to transformers tokenizer')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
|
@ -16,9 +16,9 @@ The RAG example ([rag.py](./rag.py)) is adapted from the [Official llama index R
|
|||
```bash
|
||||
pip install llama-index-readers-file llama-index-vector-stores-postgres llama-index-embeddings-huggingface
|
||||
```
|
||||
* **Install Bigdl LLM**
|
||||
* **Install IPEX-LLM**
|
||||
|
||||
Follow the instructions in [GPU Install Guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install_gpu.html) to install ipex-llm.
|
||||
Follow the instructions in [GPU Install Guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install.html) to install ipex-llm.
|
||||
|
||||
* **Database Setup (using PostgreSQL)**:
|
||||
* Linux
|
||||
|
|
@ -145,7 +145,7 @@ There is no need to set further environment variables.
|
|||
In the current directory, run the example with command:
|
||||
|
||||
```bash
|
||||
python rag.py -m <path_to_model>
|
||||
python rag.py -m <path_to_model> -t <path_to_tokenizer>
|
||||
```
|
||||
**Additional Parameters for Configuration**:
|
||||
- `-m MODEL_PATH`: **Required**, path to the LLM model
|
||||
|
|
@ -155,6 +155,7 @@ python rag.py -m <path_to_model>
|
|||
- `-q QUESTION`: question you want to ask
|
||||
- `-d DATA`: path to source data used for retrieval (in pdf format)
|
||||
- `-n N_PREDICT`: max predict tokens
|
||||
- `-t TOKENIZER_PATH`: **Required**, path to the tokenizer model
|
||||
|
||||
### 5. Example Output
|
||||
|
||||
|
|
|
|||
|
|
@ -162,11 +162,11 @@ def messages_to_prompt(messages):
|
|||
def main(args):
|
||||
embed_model = HuggingFaceEmbedding(model_name=args.embedding_model_path)
|
||||
|
||||
# Use custom LLM in IPEX-LLM
|
||||
from ipex_llm.llamaindex.llms import BigdlLLM
|
||||
llm = BigdlLLM(
|
||||
# Use custom LLM in BigDL
|
||||
from ipex_llm.llamaindex.llms import IpexLLM
|
||||
llm = IpexLLM.from_model_id(
|
||||
model_name=args.model_path,
|
||||
tokenizer_name=args.model_path,
|
||||
tokenizer_name=args.tokenizer_path,
|
||||
context_window=512,
|
||||
max_new_tokens=args.n_predict,
|
||||
generate_kwargs={"temperature": 0.7, "do_sample": False},
|
||||
|
|
@ -245,6 +245,8 @@ if __name__ == "__main__":
|
|||
help="the path to embedding model path")
|
||||
parser.add_argument('-n','--n-predict', type=int, default=32,
|
||||
help='max number of predict tokens')
|
||||
parser.add_argument('-t','--tokenizer-path',type=str,required=True,
|
||||
help='the path to transformers tokenizer')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
|
@ -26,9 +26,9 @@ from .bigdlllm import *
|
|||
from llama_index.core.base.llms.base import BaseLLM
|
||||
|
||||
__all__ = [
|
||||
"BigdlLLM",
|
||||
"IpexLLm",
|
||||
]
|
||||
|
||||
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||
"BigdlLLM": BigdlLLM,
|
||||
"IpexLLM": IpexLLM,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,6 +70,8 @@ from llama_index.core.llms.custom import CustomLLM
|
|||
|
||||
from llama_index.core.base.llms.generic_utils import (
|
||||
messages_to_prompt as generic_messages_to_prompt,
|
||||
completion_response_to_chat_response,
|
||||
stream_completion_response_to_chat_response
|
||||
)
|
||||
from llama_index.core.prompts.base import PromptTemplate
|
||||
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
|
||||
|
|
@ -85,14 +87,14 @@ DEFAULT_HUGGINGFACE_MODEL = "meta-llama/Llama-2-7b-chat-hf"
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BigdlLLM(CustomLLM):
|
||||
"""Wrapper around the BigDL-LLM
|
||||
class IpexLLM(CustomLLM):
|
||||
"""Wrapper around the IPEX-LLM
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from ipex_llm.llamaindex.llms import BigdlLLM
|
||||
llm = BigdlLLM(model_path="/path/to/llama/model")
|
||||
from ipex_llm.llamaindex.llms import IpexLLM
|
||||
llm = IpexLLM(model_path="/path/to/llama/model")
|
||||
"""
|
||||
|
||||
model_name: str = Field(
|
||||
|
|
@ -171,6 +173,10 @@ class BigdlLLM(CustomLLM):
|
|||
"`messages_to_prompt` that does so."
|
||||
),
|
||||
)
|
||||
load_low_bit: bool = Field(
|
||||
default=False,
|
||||
description="The model is low_bit model or not"
|
||||
)
|
||||
|
||||
_model: Any = PrivateAttr()
|
||||
_tokenizer: Any = PrivateAttr()
|
||||
|
|
@ -198,9 +204,10 @@ class BigdlLLM(CustomLLM):
|
|||
completion_to_prompt: Optional[Callable[[str], str]]=None,
|
||||
pydantic_program_mode: PydanticProgramMode=PydanticProgramMode.DEFAULT,
|
||||
output_parser: Optional[BaseOutputParser] = None,
|
||||
load_low_bit: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Construct BigdlLLM.
|
||||
Construct Ipex-LLM.
|
||||
|
||||
Args:
|
||||
|
||||
|
|
@ -229,6 +236,7 @@ class BigdlLLM(CustomLLM):
|
|||
completion_to_prompt: Function to convert messages to prompt.
|
||||
pydantic_program_mode: DEFAULT.
|
||||
output_parser: BaseOutputParser.
|
||||
load_low_bit: Use low_bit checkpoint.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
|
|
@ -238,15 +246,25 @@ class BigdlLLM(CustomLLM):
|
|||
if model:
|
||||
self._model = model
|
||||
else:
|
||||
try:
|
||||
self._model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, load_in_4bit=True, use_cache=True,
|
||||
trust_remote_code=True, **model_kwargs
|
||||
)
|
||||
except:
|
||||
from ipex_llm.transformers import AutoModel
|
||||
self._model = AutoModel.from_pretrained(model_name,
|
||||
load_in_4bit=True, **model_kwargs)
|
||||
if not load_low_bit:
|
||||
try:
|
||||
self._model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, load_in_4bit=True, use_cache=True,
|
||||
trust_remote_code=True, **model_kwargs
|
||||
)
|
||||
except:
|
||||
from ipex_llm.transformers import AutoModel
|
||||
self._model = AutoModel.from_pretrained(model_name,
|
||||
load_in_4bit=True,
|
||||
**model_kwargs)
|
||||
else:
|
||||
try:
|
||||
self._model = AutoModelForCausalLM.load_low_bit(
|
||||
model_name, use_cache=True,
|
||||
trust_remote_code=True, **model_kwargs)
|
||||
except:
|
||||
from ipex_llm.transformers import AutoModel
|
||||
self._model = AutoModel.load_low_bit(model_name, **model_kwargs)
|
||||
|
||||
if 'xpu' in device_map:
|
||||
self._model = self._model.to(device_map)
|
||||
|
|
@ -271,7 +289,6 @@ class BigdlLLM(CustomLLM):
|
|||
if tokenizer:
|
||||
self._tokenizer = tokenizer
|
||||
else:
|
||||
print(f"load tokenizer: {tokenizer_name}")
|
||||
try:
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)
|
||||
except:
|
||||
|
|
@ -327,6 +344,157 @@ class BigdlLLM(CustomLLM):
|
|||
output_parser=output_parser,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_model_id(
|
||||
cls,
|
||||
context_window: int = DEFAULT_CONTEXT_WINDOW,
|
||||
max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
|
||||
query_wrapper_prompt: Union[str, PromptTemplate]="{query_str}",
|
||||
tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL,
|
||||
model_name: str = DEFAULT_HUGGINGFACE_MODEL,
|
||||
model: Optional[Any] = None,
|
||||
tokenizer: Optional[Any] = None,
|
||||
device_map: Optional[str] = "auto",
|
||||
stopping_ids: Optional[List[int]] = None,
|
||||
tokenizer_kwargs: Optional[dict] = None,
|
||||
tokenizer_outputs_to_remove: Optional[list] = None,
|
||||
model_kwargs: Optional[dict] = None,
|
||||
generate_kwargs: Optional[dict] = None,
|
||||
is_chat_model: Optional[bool] = False,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
system_prompt: str = "",
|
||||
messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]]=None,
|
||||
completion_to_prompt: Optional[Callable[[str], str]]=None,
|
||||
pydantic_program_mode: PydanticProgramMode=PydanticProgramMode.DEFAULT,
|
||||
output_parser: Optional[BaseOutputParser] = None,
|
||||
):
|
||||
"""
|
||||
Construct IPEX-LLM from HuggingFace Model.
|
||||
|
||||
Args:
|
||||
|
||||
context_window: The maximum number of tokens available for input.
|
||||
max_new_tokens: The maximum number of tokens to generate.
|
||||
query_wrapper_prompt: The query wrapper prompt, containing the query placeholder.
|
||||
Should contain a `{query_str}` placeholder.
|
||||
tokenizer_name: The name of the tokenizer to use from HuggingFace.
|
||||
Unused if `tokenizer` is passed in directly.
|
||||
model_name: The model name to use from HuggingFace.
|
||||
Unused if `model` is passed in directly.
|
||||
model: The HuggingFace model.
|
||||
tokenizer: The tokenizer.
|
||||
device_map: The device_map to use. Defaults to 'auto'.
|
||||
stopping_ids: The stopping ids to use.
|
||||
Generation stops when these token IDs are predicted.
|
||||
tokenizer_kwargs: The kwargs to pass to the tokenizer.
|
||||
tokenizer_outputs_to_remove: The outputs to remove from the tokenizer.
|
||||
Sometimes huggingface tokenizers return extra inputs that cause errors.
|
||||
model_kwargs: The kwargs to pass to the model during initialization.
|
||||
generate_kwargs: The kwargs to pass to the model during generation.
|
||||
is_chat_model: Whether the model is `chat`
|
||||
callback_manager: Callback manager.
|
||||
system_prompt: The system prompt, containing any extra instructions or context.
|
||||
messages_to_prompt: Function to convert messages to prompt.
|
||||
completion_to_prompt: Function to convert messages to prompt.
|
||||
pydantic_program_mode: DEFAULT.
|
||||
output_parser: BaseOutputParser.
|
||||
|
||||
Returns:
|
||||
Ipex-LLM instance.
|
||||
"""
|
||||
return cls(
|
||||
context_window=context_window,
|
||||
max_new_tokens=max_new_tokens,
|
||||
query_wrapper_prompt=query_wrapper_prompt,
|
||||
tokenizer_name=tokenizer_name,
|
||||
model_name=model_name,
|
||||
device_map=device_map,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
model_kwargs=model_kwargs,
|
||||
generate_kwargs=generate_kwargs,
|
||||
is_chat_model=is_chat_model,
|
||||
callback_manager=callback_manager,
|
||||
system_prompt=system_prompt,
|
||||
completion_to_prompt=completion_to_prompt,
|
||||
messages_to_prompt=messages_to_prompt,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_model_id_low_bit(
|
||||
cls,
|
||||
context_window: int = DEFAULT_CONTEXT_WINDOW,
|
||||
max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
|
||||
query_wrapper_prompt: Union[str, PromptTemplate]="{query_str}",
|
||||
tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL,
|
||||
model_name: str = DEFAULT_HUGGINGFACE_MODEL,
|
||||
model: Optional[Any] = None,
|
||||
tokenizer: Optional[Any] = None,
|
||||
device_map: Optional[str] = "auto",
|
||||
stopping_ids: Optional[List[int]] = None,
|
||||
tokenizer_kwargs: Optional[dict] = None,
|
||||
tokenizer_outputs_to_remove: Optional[list] = None,
|
||||
model_kwargs: Optional[dict] = None,
|
||||
generate_kwargs: Optional[dict] = None,
|
||||
is_chat_model: Optional[bool] = False,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
system_prompt: str = "",
|
||||
messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]]=None,
|
||||
completion_to_prompt: Optional[Callable[[str], str]]=None,
|
||||
pydantic_program_mode: PydanticProgramMode=PydanticProgramMode.DEFAULT,
|
||||
output_parser: Optional[BaseOutputParser] = None,
|
||||
):
|
||||
"""
|
||||
Construct IPEX-LLM from HuggingFace Model low-bit checkpoint.
|
||||
|
||||
Args:
|
||||
|
||||
context_window: The maximum number of tokens available for input.
|
||||
max_new_tokens: The maximum number of tokens to generate.
|
||||
query_wrapper_prompt: The query wrapper prompt, containing the query placeholder.
|
||||
Should contain a `{query_str}` placeholder.
|
||||
tokenizer_name: The name of the tokenizer to use from HuggingFace.
|
||||
Unused if `tokenizer` is passed in directly.
|
||||
model_name: The model name to use from HuggingFace.
|
||||
Unused if `model` is passed in directly.
|
||||
model: The HuggingFace model.
|
||||
tokenizer: The tokenizer.
|
||||
device_map: The device_map to use. Defaults to 'auto'.
|
||||
stopping_ids: The stopping ids to use.
|
||||
Generation stops when these token IDs are predicted.
|
||||
tokenizer_kwargs: The kwargs to pass to the tokenizer.
|
||||
tokenizer_outputs_to_remove: The outputs to remove from the tokenizer.
|
||||
Sometimes huggingface tokenizers return extra inputs that cause errors.
|
||||
model_kwargs: The kwargs to pass to the model during initialization.
|
||||
generate_kwargs: The kwargs to pass to the model during generation.
|
||||
is_chat_model: Whether the model is `chat`
|
||||
callback_manager: Callback manager.
|
||||
system_prompt: The system prompt, containing any extra instructions or context.
|
||||
messages_to_prompt: Function to convert messages to prompt.
|
||||
completion_to_prompt: Function to convert messages to prompt.
|
||||
pydantic_program_mode: DEFAULT.
|
||||
output_parser: BaseOutputParser.
|
||||
|
||||
Returns:
|
||||
Ipex-LLM instance.
|
||||
"""
|
||||
return cls(
|
||||
context_window=context_window,
|
||||
max_new_tokens=max_new_tokens,
|
||||
query_wrapper_prompt=query_wrapper_prompt,
|
||||
tokenizer_name=tokenizer_name,
|
||||
model_name=model_name,
|
||||
device_map=device_map,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
model_kwargs=model_kwargs,
|
||||
generate_kwargs=generate_kwargs,
|
||||
is_chat_model=is_chat_model,
|
||||
callback_manager=callback_manager,
|
||||
system_prompt=system_prompt,
|
||||
completion_to_prompt=completion_to_prompt,
|
||||
messages_to_prompt=messages_to_prompt,
|
||||
load_low_bit=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
"""
|
||||
|
|
@ -338,7 +506,7 @@ class BigdlLLM(CustomLLM):
|
|||
Returns:
|
||||
Str of class name.
|
||||
"""
|
||||
return "BigDL_LLM"
|
||||
return "IpexLLM"
|
||||
|
||||
@property
|
||||
def metadata(self) -> LLMMetadata:
|
||||
|
|
@ -431,7 +599,7 @@ class BigdlLLM(CustomLLM):
|
|||
Returns:
|
||||
CompletionReponse after generation.
|
||||
"""
|
||||
from transformers import TextStreamer
|
||||
from transformers import TextIteratorStreamer
|
||||
full_prompt = prompt
|
||||
if not formatted:
|
||||
if self.query_wrapper_prompt:
|
||||
|
|
@ -446,9 +614,9 @@ class BigdlLLM(CustomLLM):
|
|||
if key in input_ids:
|
||||
input_ids.pop(key, None)
|
||||
|
||||
streamer = TextStreamer(self._tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
streamer = TextIteratorStreamer(self._tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
generation_kwargs = dict(
|
||||
input_ids,
|
||||
input_ids=input_ids,
|
||||
streamer=streamer,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
stopping_criteria=self._stopping_criteria,
|
||||
|
|
@ -465,3 +633,17 @@ class BigdlLLM(CustomLLM):
|
|||
yield CompletionResponse(text=text, delta=x)
|
||||
|
||||
return gen()
|
||||
|
||||
@llm_chat_callback()
|
||||
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
|
||||
prompt = self.messages_to_prompt(messages)
|
||||
completion_response = self.complete(prompt, formatted=True, **kwargs)
|
||||
return completion_response_to_chat_response(completion_response)
|
||||
|
||||
@llm_chat_callback()
|
||||
def stream_chat(
|
||||
self, messages: Sequence[ChatMessage], **kwargs: Any
|
||||
) -> ChatResponseGen:
|
||||
prompt = self.messages_to_prompt(messages)
|
||||
completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
|
||||
return stream_completion_response_to_chat_response(completion_response)
|
||||
|
|
|
|||
Loading…
Reference in a new issue