diff --git a/python/llm/example/CPU/LlamaIndex/README.md b/python/llm/example/CPU/LlamaIndex/README.md
index 6427ace6..ce731086 100644
--- a/python/llm/example/CPU/LlamaIndex/README.md
+++ b/python/llm/example/CPU/LlamaIndex/README.md
@@ -64,6 +64,7 @@ python rag.py -m 
 - `-p PASSWORD`: password in the PostgreSQL database
 - `-q QUESTION`: question you want to ask
 - `-d DATA`: path to source data used for retrieval (in pdf format)
+- `-n N_PREDICT`: max predict tokens
 
 ### Example Output
 
diff --git a/python/llm/example/CPU/LlamaIndex/rag.py b/python/llm/example/CPU/LlamaIndex/rag.py
index cf11b5ad..9fd81ca1 100644
--- a/python/llm/example/CPU/LlamaIndex/rag.py
+++ b/python/llm/example/CPU/LlamaIndex/rag.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 
+
+import torch
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from sqlalchemy import make_url
 from llama_index.vector_stores.postgres import PGVectorStore
@@ -167,7 +169,7 @@ def main(args):
         model_name=args.model_path,
         tokenizer_name=args.model_path,
         context_window=512,
-        max_new_tokens=32,
+        max_new_tokens=args.n_predict,
         generate_kwargs={"temperature": 0.7, "do_sample": False},
         model_kwargs={},
         messages_to_prompt=messages_to_prompt,
@@ -242,6 +244,8 @@ if __name__ == "__main__":
                         help="the password of the user in the database")
    parser.add_argument('-e','--embedding-model-path',default="BAAI/bge-small-en",
                        help="the path to embedding model path")
+    parser.add_argument('-n','--n-predict', type=int, default=32,
+                        help='max number of predict tokens')
 
     args = parser.parse_args()
     main(args)
\ No newline at end of file
diff --git a/python/llm/example/GPU/LlamaIndex/README.md b/python/llm/example/GPU/LlamaIndex/README.md
index bc6b3322..5eb836e4 100644
--- a/python/llm/example/GPU/LlamaIndex/README.md
+++ b/python/llm/example/GPU/LlamaIndex/README.md
@@ -154,6 +154,7 @@ python rag.py -m 
 - `-p PASSWORD`: password in the PostgreSQL database
 - `-q QUESTION`: question you want to ask
 - `-d DATA`: path to source data used for retrieval (in pdf format)
+- `-n N_PREDICT`: max predict tokens
 
 ### 5. Example Output
 
diff --git a/python/llm/example/GPU/LlamaIndex/rag.py b/python/llm/example/GPU/LlamaIndex/rag.py
index c105370d..7fb1146e 100644
--- a/python/llm/example/GPU/LlamaIndex/rag.py
+++ b/python/llm/example/GPU/LlamaIndex/rag.py
@@ -168,7 +168,7 @@ def main(args):
         model_name=args.model_path,
         tokenizer_name=args.model_path,
         context_window=512,
-        max_new_tokens=32,
+        max_new_tokens=args.n_predict,
         generate_kwargs={"temperature": 0.7, "do_sample": False},
         model_kwargs={},
         messages_to_prompt=messages_to_prompt,
@@ -243,6 +243,8 @@ if __name__ == "__main__":
                         help="the password of the user in the database")
    parser.add_argument('-e','--embedding-model-path',default="BAAI/bge-small-en",
                        help="the path to embedding model path")
+    parser.add_argument('-n','--n-predict', type=int, default=32,
+                        help='max number of predict tokens')
 
     args = parser.parse_args()
     main(args)
\ No newline at end of file
diff --git a/python/llm/src/bigdl/llm/llamaindex/llms/bigdlllm.py b/python/llm/src/bigdl/llm/llamaindex/llms/bigdlllm.py
index c157acdf..d6f61854 100644
--- a/python/llm/src/bigdl/llm/llamaindex/llms/bigdlllm.py
+++ b/python/llm/src/bigdl/llm/llamaindex/llms/bigdlllm.py
@@ -235,9 +235,18 @@
         """
         model_kwargs = model_kwargs or {}
         from bigdl.llm.transformers import AutoModelForCausalLM
-        self._model = model or AutoModelForCausalLM.from_pretrained(
-            model_name, load_in_4bit=True, **model_kwargs
-        )
+        if model:
+            self._model = model
+        else:
+            try:
+                self._model = AutoModelForCausalLM.from_pretrained(
+                    model_name, load_in_4bit=True, use_cache=True,
+                    trust_remote_code=True, **model_kwargs
+                )
+            except:
+                from bigdl.llm.transformers import AutoModel
+                self._model = AutoModel.from_pretrained(model_name,
+                                                        load_in_4bit=True, **model_kwargs)
 
         if 'xpu' in device_map:
             self._model = self._model.to(device_map)
@@ -259,9 +268,15 @@
         if "max_length" not in tokenizer_kwargs:
             tokenizer_kwargs["max_length"] = context_window
 
-        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(
-            tokenizer_name, **tokenizer_kwargs
-        )
+        if tokenizer:
+            self._tokenizer = tokenizer
+        else:
+            print(f"load tokenizer: {tokenizer_name}")
+            try:
+                self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)
+            except:
+                self._tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name,
+                                                                 trust_remote_code=True)
 
         if tokenizer_name != model_name:
             logger.warning(
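
For reference, a quick usage sketch of the new `-n`/`--n-predict` option introduced above. The model path, password, question, and data path below are illustrative placeholders, not values taken from this patch:

```bash
# Hypothetical invocation of the rag.py example; substitute your own model
# checkpoint, PostgreSQL password, and PDF. -n raises the generation cap
# from the default of 32 new tokens to 64.
python rag.py -m ./models/Llama-2-7b-chat-hf \
              -p <postgres_password> \
              -q "What does the document say about retrieval?" \
              -d ./data/sample.pdf \
              -n 64
```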