Fix llamaindex AutoTokenizer bug (#10345)

* fix tokenizer
* fix AutoTokenizer bug
* modify code style

parent 2a10b53d73
commit 9026c08633

5 changed files with 31 additions and 8 deletions
@@ -64,6 +64,7 @@ python rag.py -m <path_to_model>
 - `-p PASSWORD`: password in the PostgreSQL database
 - `-q QUESTION`: question you want to ask
 - `-d DATA`: path to source data used for retrieval (in pdf format)
+- `-n N_PREDICT`: max predict tokens
 
 ### Example Output
 
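With the new flag, a run that caps generation at 64 tokens might look like `python rag.py -m <path_to_model> -q "what is BigDL?" -n 64` (question text illustrative; other flags as documented above).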
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 
+import torch
+
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from sqlalchemy import make_url
 from llama_index.vector_stores.postgres import PGVectorStore
@@ -167,7 +169,7 @@ def main(args):
         model_name=args.model_path,
         tokenizer_name=args.model_path,
         context_window=512,
-        max_new_tokens=32,
+        max_new_tokens=args.n_predict,
         generate_kwargs={"temperature": 0.7, "do_sample": False},
         model_kwargs={},
         messages_to_prompt=messages_to_prompt,
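For orientation, a minimal sketch of the call these kwargs belong to, assuming a `BigdlLLM(...)` construction inside `main` (the `llm` variable name is illustrative):

```python
# Sketch: the hard-coded limit of 32 is replaced by the parsed CLI value,
# so users control generation length via -n/--n-predict.
llm = BigdlLLM(
    model_name=args.model_path,
    tokenizer_name=args.model_path,
    context_window=512,
    max_new_tokens=args.n_predict,
    generate_kwargs={"temperature": 0.7, "do_sample": False},
    model_kwargs={},
    messages_to_prompt=messages_to_prompt,
)
```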
@@ -242,6 +244,8 @@ if __name__ == "__main__":
                         help="the password of the user in the database")
     parser.add_argument('-e','--embedding-model-path',default="BAAI/bge-small-en",
                         help="the path to embedding model path")
+    parser.add_argument('-n','--n-predict', type=int, default=32,
+                        help='max number of predict tokens')
     args = parser.parse_args()
 
     main(args)
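A standalone sketch of just the new option, showing how it parses (note that argparse maps `--n-predict` to the attribute `n_predict`):

```python
import argparse

# Reproduce only the new option in isolation to show its behavior.
parser = argparse.ArgumentParser()
parser.add_argument('-n', '--n-predict', type=int, default=32,
                    help='max number of predict tokens')

args = parser.parse_args(['-n', '64'])
print(args.n_predict)  # -> 64; omitting -n yields the default of 32
```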
@@ -154,6 +154,7 @@ python rag.py -m <path_to_model>
 - `-p PASSWORD`: password in the PostgreSQL database
 - `-q QUESTION`: question you want to ask
 - `-d DATA`: path to source data used for retrieval (in pdf format)
+- `-n N_PREDICT`: max predict tokens
 
 ### 5. Example Output
 
@@ -168,7 +168,7 @@ def main(args):
         model_name=args.model_path,
         tokenizer_name=args.model_path,
         context_window=512,
-        max_new_tokens=32,
+        max_new_tokens=args.n_predict,
         generate_kwargs={"temperature": 0.7, "do_sample": False},
         model_kwargs={},
         messages_to_prompt=messages_to_prompt,
@@ -243,6 +243,8 @@ if __name__ == "__main__":
                         help="the password of the user in the database")
     parser.add_argument('-e','--embedding-model-path',default="BAAI/bge-small-en",
                         help="the path to embedding model path")
+    parser.add_argument('-n','--n-predict', type=int, default=32,
+                        help='max number of predict tokens')
     args = parser.parse_args()
 
     main(args)
@@ -235,9 +235,18 @@ class BigdlLLM(CustomLLM):
         """
         model_kwargs = model_kwargs or {}
         from bigdl.llm.transformers import AutoModelForCausalLM
-        self._model = model or AutoModelForCausalLM.from_pretrained(
-            model_name, load_in_4bit=True, **model_kwargs
-        )
+        if model:
+            self._model = model
+        else:
+            try:
+                self._model = AutoModelForCausalLM.from_pretrained(
+                    model_name, load_in_4bit=True, use_cache=True,
+                    trust_remote_code=True, **model_kwargs
+                )
+            except:
+                from bigdl.llm.transformers import AutoModel
+                self._model = AutoModel.from_pretrained(model_name,
+                                                        load_in_4bit=True, **model_kwargs)
 
         if 'xpu' in device_map:
             self._model = self._model.to(device_map)
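Standalone, the new loading logic is a try-then-fall-back pattern; a sketch assuming `bigdl-llm` is installed (the helper name is hypothetical):

```python
from bigdl.llm.transformers import AutoModelForCausalLM

def load_bigdl_model(model_name, **model_kwargs):
    # Hypothetical helper mirroring the hunk above: try the causal-LM
    # loader first; checkpoints it cannot handle (e.g. ChatGLM-style
    # models that only register with AutoModel) use the fallback.
    try:
        return AutoModelForCausalLM.from_pretrained(
            model_name, load_in_4bit=True, use_cache=True,
            trust_remote_code=True, **model_kwargs)
    except Exception:
        from bigdl.llm.transformers import AutoModel
        return AutoModel.from_pretrained(model_name,
                                         load_in_4bit=True, **model_kwargs)
```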
@@ -259,9 +268,15 @@ class BigdlLLM(CustomLLM):
         if "max_length" not in tokenizer_kwargs:
             tokenizer_kwargs["max_length"] = context_window
 
-        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(
-            tokenizer_name, **tokenizer_kwargs
-        )
+        if tokenizer:
+            self._tokenizer = tokenizer
+        else:
+            print(f"load tokenizer: {tokenizer_name}")
+            try:
+                self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)
+            except:
+                self._tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name,
+                                                                 trust_remote_code=True)
 
         if tokenizer_name != model_name:
             logger.warning(
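The AutoTokenizer fix itself follows the same shape: if `AutoTokenizer.from_pretrained` cannot resolve a checkpoint's tokenizer, retry with `LlamaTokenizer`. A standalone sketch using plain `transformers` (the helper name is hypothetical):

```python
from transformers import AutoTokenizer, LlamaTokenizer

def load_tokenizer(tokenizer_name, **tokenizer_kwargs):
    # Hypothetical helper mirroring the hunk above: some checkpoints ship
    # only a SentencePiece tokenizer.model that AutoTokenizer fails to
    # resolve; loading LlamaTokenizer directly handles those.
    try:
        return AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)
    except Exception:
        return LlamaTokenizer.from_pretrained(tokenizer_name,
                                              trust_remote_code=True)
```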