From 9d8ba64c0db24f79fe614919c60c28428f5d0969 Mon Sep 17 00:00:00 2001
From: Zhicun <59141989+ivy-lv11@users.noreply.github.com>
Date: Sun, 7 Apr 2024 13:51:34 +0800
Subject: [PATCH] LlamaIndex: add tokenizer_id and support chat (#10590)

* add tokenizer_id

* fix

* modify

* add from_model_id and from_model_id_low_bit

* fix typo and add comment

* fix python code style

---------

Co-authored-by: pengyb2001 <284261055@qq.com>
---
 python/llm/example/CPU/LlamaIndex/README.md  |  11 +-
 python/llm/example/CPU/LlamaIndex/rag.py     |  12 +-
 python/llm/example/GPU/LlamaIndex/README.md  |   7 +-
 python/llm/example/GPU/LlamaIndex/rag.py     |  10 +-
 .../src/ipex_llm/llamaindex/llms/__init__.py |   4 +-
 .../src/ipex_llm/llamaindex/llms/bigdlllm.py | 220 ++++++++++++++++--
 6 files changed, 226 insertions(+), 38 deletions(-)

diff --git a/python/llm/example/CPU/LlamaIndex/README.md b/python/llm/example/CPU/LlamaIndex/README.md
index 0f65e403..be50d92d 100644
--- a/python/llm/example/CPU/LlamaIndex/README.md
+++ b/python/llm/example/CPU/LlamaIndex/README.md
@@ -4,10 +4,6 @@
 This folder contains examples showcasing how to use [**LlamaIndex**](https://github.com/run-llama/llama_index) with `ipex-llm`.
 > [**LlamaIndex**](https://github.com/run-llama/llama_index) is a data framework designed to improve large language models by providing tools for easier data ingestion, management, and application integration.
 
-## Prerequisites
-
-Ensure `ipex-llm` is installed by following the [IPEX-LLM Installation Guide](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm#install) before proceeding with the examples provided here.
-
 ## Retrieval-Augmented Generation (RAG) Example
 The RAG example ([rag.py](./rag.py)) is adapted from the [Official llama index RAG example](https://docs.llamaindex.ai/en/stable/examples/low_level/oss_ingestion_retrieval.html). This example builds a pipeline to ingest data (e.g. llama2 paper in pdf format) into a vector database (e.g. PostgreSQL), and then build a retrieval pipeline from that vector database.
@@ -21,6 +17,10 @@ The RAG example ([rag.py](./rag.py)) is adapted from the [Official llama index R
   pip install llama-index-readers-file llama-index-vector-stores-postgres llama-index-embeddings-huggingface
   ```
+* **Install IPEX-LLM**
+Ensure `ipex-llm` is installed by following the [IPEX-LLM Installation Guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install.html) before proceeding with the examples provided here.
+
+
 * **Database Setup (using PostgreSQL)**:
   * Installation:
     ```bash
     sudo apt-get install postgresql-client
     sudo apt-get install postgresql
     ```
@@ -55,7 +55,7 @@ The RAG example ([rag.py](./rag.py)) is adapted from the [Official llama index R
 In the current directory, run the example with command:
 ```bash
-python rag.py -m <path_to_model>
+python rag.py -m <path_to_model> -t <path_to_tokenizer>
 ```
 **Additional Parameters for Configuration**:
 - `-m MODEL_PATH`: **Required**, path to the LLM model
@@ -65,6 +65,7 @@ python rag.py -m
 - `-q QUESTION`: question you want to ask
 - `-d DATA`: path to source data used for retrieval (in pdf format)
 - `-n N_PREDICT`: max predict tokens
+- `-t TOKENIZER_PATH`: **Required**, path to the tokenizer model
 
 ### Example Output
diff --git a/python/llm/example/CPU/LlamaIndex/rag.py b/python/llm/example/CPU/LlamaIndex/rag.py
index 9f26e55d..5759c624 100644
--- a/python/llm/example/CPU/LlamaIndex/rag.py
+++ b/python/llm/example/CPU/LlamaIndex/rag.py
@@ -163,11 +163,11 @@ def messages_to_prompt(messages):
 
 def main(args):
     embed_model = HuggingFaceEmbedding(model_name=args.embedding_model_path)
-    # Use custom LLM in IPEX-LLM
-    from ipex_llm.llamaindex.llms import BigdlLLM
-    llm = BigdlLLM(
+    # Use custom LLM in IPEX-LLM
+    from ipex_llm.llamaindex.llms import IpexLLM
+    llm = IpexLLM.from_model_id(
         model_name=args.model_path,
-        tokenizer_name=args.model_path,
+        tokenizer_name=args.tokenizer_path,
         context_window=512,
         max_new_tokens=args.n_predict,
         generate_kwargs={"temperature": 0.7, "do_sample": False},
@@ -244,8 +244,10 @@ if __name__ == "__main__":
                         help="the password of the user in the database")
     parser.add_argument('-e','--embedding-model-path',default="BAAI/bge-small-en",
                         help="the path to embedding model path")
-    parser.add_argument('-n','--n-predict', type=int, default=32,
+    parser.add_argument('-n','--n-predict', type=int, default=64,
                         help='max number of predict tokens')
+    parser.add_argument('-t','--tokenizer-path',type=str,required=True,
+                        help='the path to transformers tokenizer')
     args = parser.parse_args()
     main(args)
\ No newline at end of file
diff --git a/python/llm/example/GPU/LlamaIndex/README.md b/python/llm/example/GPU/LlamaIndex/README.md
index 74ac0fd0..b01d4c47 100644
--- a/python/llm/example/GPU/LlamaIndex/README.md
+++ b/python/llm/example/GPU/LlamaIndex/README.md
@@ -16,9 +16,9 @@ The RAG example ([rag.py](./rag.py)) is adapted from the [Official llama index R
   ```bash
   pip install llama-index-readers-file llama-index-vector-stores-postgres llama-index-embeddings-huggingface
   ```
-* **Install Bigdl LLM**
+* **Install IPEX-LLM**
 
-  Follow the instructions in [GPU Install Guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install_gpu.html) to install ipex-llm.
+  Follow the instructions in [GPU Install Guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install.html) to install ipex-llm.
 
 * **Database Setup (using PostgreSQL)**:
   * Linux
@@ -145,7 +145,7 @@ There is no need to set further environment variables.
 In the current directory, run the example with command:
 ```bash
-python rag.py -m <path_to_model>
+python rag.py -m <path_to_model> -t <path_to_tokenizer>
 ```
 **Additional Parameters for Configuration**:
 - `-m MODEL_PATH`: **Required**, path to the LLM model
@@ -155,6 +155,7 @@ python rag.py -m
 - `-q QUESTION`: question you want to ask
 - `-d DATA`: path to source data used for retrieval (in pdf format)
 - `-n N_PREDICT`: max predict tokens
+- `-t TOKENIZER_PATH`: **Required**, path to the tokenizer model
 
 ### 5. Example Output
diff --git a/python/llm/example/GPU/LlamaIndex/rag.py b/python/llm/example/GPU/LlamaIndex/rag.py
index 87838715..fef32047 100644
--- a/python/llm/example/GPU/LlamaIndex/rag.py
+++ b/python/llm/example/GPU/LlamaIndex/rag.py
@@ -162,11 +162,11 @@ def messages_to_prompt(messages):
 
 def main(args):
     embed_model = HuggingFaceEmbedding(model_name=args.embedding_model_path)
-    # Use custom LLM in IPEX-LLM
-    from ipex_llm.llamaindex.llms import BigdlLLM
-    llm = BigdlLLM(
+    # Use custom LLM in IPEX-LLM
+    from ipex_llm.llamaindex.llms import IpexLLM
+    llm = IpexLLM.from_model_id(
         model_name=args.model_path,
-        tokenizer_name=args.model_path,
+        tokenizer_name=args.tokenizer_path,
         context_window=512,
         max_new_tokens=args.n_predict,
         generate_kwargs={"temperature": 0.7, "do_sample": False},
@@ -245,6 +245,8 @@ if __name__ == "__main__":
                         help="the path to embedding model path")
     parser.add_argument('-n','--n-predict', type=int, default=32,
                         help='max number of predict tokens')
+    parser.add_argument('-t','--tokenizer-path',type=str,required=True,
+                        help='the path to transformers tokenizer')
     args = parser.parse_args()
     main(args)
\ No newline at end of file
diff --git a/python/llm/src/ipex_llm/llamaindex/llms/__init__.py b/python/llm/src/ipex_llm/llamaindex/llms/__init__.py
index 54fc4aa5..a36e925d 100644
--- a/python/llm/src/ipex_llm/llamaindex/llms/__init__.py
+++ b/python/llm/src/ipex_llm/llamaindex/llms/__init__.py
@@ -26,9 +26,9 @@
 from .bigdlllm import *
 from llama_index.core.base.llms.base import BaseLLM
 
 __all__ = [
-    "BigdlLLM",
+    "IpexLLM",
 ]
 
 type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
-    "BigdlLLM": BigdlLLM,
+    "IpexLLM": IpexLLM,
 }
diff --git a/python/llm/src/ipex_llm/llamaindex/llms/bigdlllm.py b/python/llm/src/ipex_llm/llamaindex/llms/bigdlllm.py
index 96550f6a..2a450ed1 100644
--- a/python/llm/src/ipex_llm/llamaindex/llms/bigdlllm.py
+++ b/python/llm/src/ipex_llm/llamaindex/llms/bigdlllm.py
@@ -70,6 +70,8 @@
 from llama_index.core.llms.custom import CustomLLM
 from llama_index.core.base.llms.generic_utils import (
     messages_to_prompt as generic_messages_to_prompt,
+    completion_response_to_chat_response,
+    stream_completion_response_to_chat_response
 )
 from llama_index.core.prompts.base import PromptTemplate
 from llama_index.core.types import BaseOutputParser, PydanticProgramMode
@@ -85,14 +87,14 @@
 DEFAULT_HUGGINGFACE_MODEL = "meta-llama/Llama-2-7b-chat-hf"
 
 logger = logging.getLogger(__name__)
 
-class BigdlLLM(CustomLLM):
-    """Wrapper around the BigDL-LLM
+class IpexLLM(CustomLLM):
+    """Wrapper around IPEX-LLM.
 
     Example:
         .. code-block:: python
 
-            from ipex_llm.llamaindex.llms import BigdlLLM
-            llm = BigdlLLM(model_path="/path/to/llama/model")
+            from ipex_llm.llamaindex.llms import IpexLLM
+            llm = IpexLLM(model_name="/path/to/llama/model", tokenizer_name="/path/to/llama/model")
     """
 
     model_name: str = Field(
@@ -171,6 +173,10 @@ class BigdlLLM(CustomLLM):
             "`messages_to_prompt` that does so."
         ),
     )
+    load_low_bit: bool = Field(
+        default=False,
+        description="Whether the model to load is a saved low-bit checkpoint"
+    )
 
     _model: Any = PrivateAttr()
     _tokenizer: Any = PrivateAttr()
@@ -198,9 +204,10 @@ class BigdlLLM(CustomLLM):
         completion_to_prompt: Optional[Callable[[str], str]]=None,
         pydantic_program_mode: PydanticProgramMode=PydanticProgramMode.DEFAULT,
         output_parser: Optional[BaseOutputParser] = None,
+        load_low_bit: bool = False
     ) -> None:
         """
-        Construct BigdlLLM.
+        Construct IpexLLM.
 
         Args:
 
@@ -229,6 +236,7 @@ class BigdlLLM(CustomLLM):
             completion_to_prompt: Function to convert messages to prompt.
             pydantic_program_mode: DEFAULT.
             output_parser: BaseOutputParser.
+            load_low_bit: Whether to load a saved low-bit checkpoint.
 
         Returns:
             None.
@@ -238,15 +246,25 @@
         if model:
             self._model = model
         else:
-            try:
-                self._model = AutoModelForCausalLM.from_pretrained(
-                    model_name, load_in_4bit=True, use_cache=True,
-                    trust_remote_code=True, **model_kwargs
-                )
-            except:
-                from ipex_llm.transformers import AutoModel
-                self._model = AutoModel.from_pretrained(model_name,
-                                                        load_in_4bit=True, **model_kwargs)
+            if not load_low_bit:
+                try:
+                    self._model = AutoModelForCausalLM.from_pretrained(
+                        model_name, load_in_4bit=True, use_cache=True,
+                        trust_remote_code=True, **model_kwargs
+                    )
+                except:
+                    from ipex_llm.transformers import AutoModel
+                    self._model = AutoModel.from_pretrained(model_name,
+                                                            load_in_4bit=True,
+                                                            **model_kwargs)
+            else:
+                try:
+                    self._model = AutoModelForCausalLM.load_low_bit(
+                        model_name, use_cache=True,
+                        trust_remote_code=True, **model_kwargs)
+                except:
+                    from ipex_llm.transformers import AutoModel
+                    self._model = AutoModel.load_low_bit(model_name, **model_kwargs)
 
         if 'xpu' in device_map:
             self._model = self._model.to(device_map)
@@ -271,7 +289,6 @@
         if tokenizer:
             self._tokenizer = tokenizer
         else:
-            print(f"load tokenizer: {tokenizer_name}")
            try:
                self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)
            except:
@@ -327,6 +344,157 @@
             output_parser=output_parser,
         )
 
+    @classmethod
+    def from_model_id(
+        cls,
+        context_window: int = DEFAULT_CONTEXT_WINDOW,
+        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
+        query_wrapper_prompt: Union[str, PromptTemplate]="{query_str}",
+        tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL,
+        model_name: str = DEFAULT_HUGGINGFACE_MODEL,
+        model: Optional[Any] = None,
+        tokenizer: Optional[Any] = None,
+        device_map: Optional[str] = "auto",
+        stopping_ids: Optional[List[int]] = None,
+        tokenizer_kwargs: Optional[dict] = None,
+        tokenizer_outputs_to_remove: Optional[list] = None,
+        model_kwargs: Optional[dict] = None,
+        generate_kwargs: Optional[dict] = None,
+        is_chat_model: Optional[bool] = False,
+        callback_manager: Optional[CallbackManager] = None,
+        system_prompt: str = "",
+        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]]=None,
+        completion_to_prompt: Optional[Callable[[str], str]]=None,
+        pydantic_program_mode: PydanticProgramMode=PydanticProgramMode.DEFAULT,
+        output_parser: Optional[BaseOutputParser] = None,
+    ):
+        """
+        Construct IpexLLM from a HuggingFace model.
+
+        Args:
+
+            context_window: The maximum number of tokens available for input.
+            max_new_tokens: The maximum number of tokens to generate.
+            query_wrapper_prompt: The query wrapper prompt, containing the query placeholder.
+                Should contain a `{query_str}` placeholder.
+            tokenizer_name: The name of the tokenizer to use from HuggingFace.
+                Unused if `tokenizer` is passed in directly.
+            model_name: The model name to use from HuggingFace.
+                Unused if `model` is passed in directly.
+            model: The HuggingFace model.
+            tokenizer: The tokenizer.
+            device_map: The device_map to use. Defaults to 'auto'.
+            stopping_ids: The stopping ids to use.
+                Generation stops when these token IDs are predicted.
+            tokenizer_kwargs: The kwargs to pass to the tokenizer.
+            tokenizer_outputs_to_remove: The outputs to remove from the tokenizer.
+                Sometimes huggingface tokenizers return extra inputs that cause errors.
+            model_kwargs: The kwargs to pass to the model during initialization.
+            generate_kwargs: The kwargs to pass to the model during generation.
+            is_chat_model: Whether the model is a chat model.
+            callback_manager: Callback manager.
+            system_prompt: The system prompt, containing any extra instructions or context.
+            messages_to_prompt: Function to convert messages to prompt.
+            completion_to_prompt: Function to convert a completion to a prompt.
+            pydantic_program_mode: DEFAULT.
+            output_parser: BaseOutputParser.
+
+        Returns:
+            IpexLLM instance.
+        """
+        return cls(
+            context_window=context_window,
+            max_new_tokens=max_new_tokens,
+            query_wrapper_prompt=query_wrapper_prompt,
+            tokenizer_name=tokenizer_name,
+            model_name=model_name,
+            device_map=device_map,
+            tokenizer_kwargs=tokenizer_kwargs,
+            model_kwargs=model_kwargs,
+            generate_kwargs=generate_kwargs,
+            is_chat_model=is_chat_model,
+            callback_manager=callback_manager,
+            system_prompt=system_prompt,
+            completion_to_prompt=completion_to_prompt,
+            messages_to_prompt=messages_to_prompt,
+        )
+
+    @classmethod
+    def from_model_id_low_bit(
+        cls,
+        context_window: int = DEFAULT_CONTEXT_WINDOW,
+        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
+        query_wrapper_prompt: Union[str, PromptTemplate]="{query_str}",
+        tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL,
+        model_name: str = DEFAULT_HUGGINGFACE_MODEL,
+        model: Optional[Any] = None,
+        tokenizer: Optional[Any] = None,
+        device_map: Optional[str] = "auto",
+        stopping_ids: Optional[List[int]] = None,
+        tokenizer_kwargs: Optional[dict] = None,
+        tokenizer_outputs_to_remove: Optional[list] = None,
+        model_kwargs: Optional[dict] = None,
+        generate_kwargs: Optional[dict] = None,
+        is_chat_model: Optional[bool] = False,
+        callback_manager: Optional[CallbackManager] = None,
+        system_prompt: str = "",
+        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]]=None,
+        completion_to_prompt: Optional[Callable[[str], str]]=None,
+        pydantic_program_mode: PydanticProgramMode=PydanticProgramMode.DEFAULT,
+        output_parser: Optional[BaseOutputParser] = None,
+    ):
+        """
+        Construct IpexLLM from a saved low-bit model checkpoint.
+
+        Args:
+
+            context_window: The maximum number of tokens available for input.
+            max_new_tokens: The maximum number of tokens to generate.
+            query_wrapper_prompt: The query wrapper prompt, containing the query placeholder.
+                Should contain a `{query_str}` placeholder.
+            tokenizer_name: The name of the tokenizer to use from HuggingFace.
+                Unused if `tokenizer` is passed in directly.
+            model_name: The model name to use from HuggingFace.
+                Unused if `model` is passed in directly.
+            model: The HuggingFace model.
+            tokenizer: The tokenizer.
+            device_map: The device_map to use. Defaults to 'auto'.
+            stopping_ids: The stopping ids to use.
+                Generation stops when these token IDs are predicted.
+            tokenizer_kwargs: The kwargs to pass to the tokenizer.
+            tokenizer_outputs_to_remove: The outputs to remove from the tokenizer.
+                Sometimes huggingface tokenizers return extra inputs that cause errors.
+            model_kwargs: The kwargs to pass to the model during initialization.
+            generate_kwargs: The kwargs to pass to the model during generation.
+            is_chat_model: Whether the model is a chat model.
+            callback_manager: Callback manager.
+            system_prompt: The system prompt, containing any extra instructions or context.
+            messages_to_prompt: Function to convert messages to prompt.
+            completion_to_prompt: Function to convert a completion to a prompt.
+            pydantic_program_mode: DEFAULT.
+            output_parser: BaseOutputParser.
+
+        Returns:
+            IpexLLM instance.
+        """
+        return cls(
+            context_window=context_window,
+            max_new_tokens=max_new_tokens,
+            query_wrapper_prompt=query_wrapper_prompt,
+            tokenizer_name=tokenizer_name,
+            model_name=model_name,
+            device_map=device_map,
+            tokenizer_kwargs=tokenizer_kwargs,
+            model_kwargs=model_kwargs,
+            generate_kwargs=generate_kwargs,
+            is_chat_model=is_chat_model,
+            callback_manager=callback_manager,
+            system_prompt=system_prompt,
+            completion_to_prompt=completion_to_prompt,
+            messages_to_prompt=messages_to_prompt,
+            load_low_bit=True,
+        )
+
     @classmethod
     def class_name(cls) -> str:
         """
@@ -338,7 +506,7 @@
         Returns:
             Str of class name.
         """
-        return "BigDL_LLM"
+        return "IpexLLM"
 
     @property
     def metadata(self) -> LLMMetadata:
@@ -431,7 +599,7 @@
         Returns:
            CompletionReponse after generation.
        """
-        from transformers import TextStreamer
+        from transformers import TextIteratorStreamer
        full_prompt = prompt
        if not formatted:
            if self.query_wrapper_prompt:
@@ -446,9 +614,9 @@
            if key in input_ids:
                input_ids.pop(key, None)
 
-        streamer = TextStreamer(self._tokenizer, skip_prompt=True, skip_special_tokens=True)
+        streamer = TextIteratorStreamer(self._tokenizer, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = dict(
-            input_ids,
+            input_ids=input_ids,
             streamer=streamer,
             max_new_tokens=self.max_new_tokens,
             stopping_criteria=self._stopping_criteria,
@@ -465,3 +633,17 @@
             yield CompletionResponse(text=text, delta=x)
 
         return gen()
+
+    @llm_chat_callback()
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        prompt = self.messages_to_prompt(messages)
+        completion_response = self.complete(prompt, formatted=True, **kwargs)
+        return completion_response_to_chat_response(completion_response)
+
+    @llm_chat_callback()
+    def stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        prompt = self.messages_to_prompt(messages)
+        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
+        return stream_completion_response_to_chat_response(completion_response)
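Usage note: below is a minimal sketch of constructing the renamed wrapper through the new `from_model_id` entry point, mirroring the updated `rag.py` examples in this patch. The checkpoint paths are placeholders and the keyword values simply follow the CPU example; treat it as a sketch rather than a canonical snippet.

```python
from ipex_llm.llamaindex.llms import IpexLLM

# Build the LLM with IPEX-LLM low-bit optimization; the tokenizer is now passed
# separately, matching the new -t/--tokenizer-path argument in rag.py.
llm = IpexLLM.from_model_id(
    model_name="/path/to/llama-2-7b-chat-hf",      # placeholder model path
    tokenizer_name="/path/to/llama-2-7b-chat-hf",  # may point to a different tokenizer
    context_window=512,
    max_new_tokens=64,
    generate_kwargs={"temperature": 0.7, "do_sample": False},
    model_kwargs={},
)

print(llm.complete("What is IPEX-LLM?").text)
```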
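The `from_model_id_low_bit` path assumes a checkpoint that has already been converted and saved in IPEX-LLM's low-bit format. A rough sketch of that workflow follows; `save_low_bit()` comes from `ipex_llm.transformers` and is not part of this patch, and all paths are placeholders.

```python
from ipex_llm.transformers import AutoModelForCausalLM
from ipex_llm.llamaindex.llms import IpexLLM

# One-time conversion: quantize the original checkpoint and save a low-bit copy.
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/llama-2-7b-chat-hf",      # placeholder original checkpoint
    load_in_4bit=True,
    trust_remote_code=True,
)
model.save_low_bit("/path/to/llama-2-7b-chat-hf-low-bit")

# Later runs skip quantization: from_model_id_low_bit sets load_low_bit=True,
# so __init__ calls load_low_bit() instead of from_pretrained(load_in_4bit=True).
llm = IpexLLM.from_model_id_low_bit(
    model_name="/path/to/llama-2-7b-chat-hf-low-bit",
    tokenizer_name="/path/to/llama-2-7b-chat-hf",
    context_window=512,
    max_new_tokens=64,
)
```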
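The new `chat` and `stream_chat` overrides fold the message list into a single prompt with `messages_to_prompt` and then reuse `complete`/`stream_complete`. A short sketch of calling them, with `llm` built as in the sketches above:

```python
from llama_index.core.llms import ChatMessage

messages = [
    ChatMessage(role="system", content="You are a helpful assistant."),
    ChatMessage(role="user", content="What is IPEX-LLM?"),
]

# Blocking chat: the messages are rendered to one prompt and answered in full.
print(llm.chat(messages))

# Streaming chat: stream_chat() yields ChatResponse chunks as tokens arrive.
for chunk in llm.stream_chat(messages):
    print(chunk.delta, end="", flush=True)
```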