Fix llamaindex AutoTokenizer bug (#10345)
* fix tokenizer * fix AutoTokenizer bug * modify code style
This commit is contained in:
parent
2a10b53d73
commit
9026c08633
5 changed files with 31 additions and 8 deletions
|
|
@ -64,6 +64,7 @@ python rag.py -m <path_to_model>
|
|||
- `-p PASSWORD`: password in the PostgreSQL database
|
||||
- `-q QUESTION`: question you want to ask
|
||||
- `-d DATA`: path to source data used for retrieval (in pdf format)
|
||||
- `-n N_PREDICT`: max predict tokens
|
||||
|
||||
### Example Output
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import torch
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from sqlalchemy import make_url
|
||||
from llama_index.vector_stores.postgres import PGVectorStore
|
||||
|
|
@ -167,7 +169,7 @@ def main(args):
|
|||
model_name=args.model_path,
|
||||
tokenizer_name=args.model_path,
|
||||
context_window=512,
|
||||
max_new_tokens=32,
|
||||
max_new_tokens=args.n_predict,
|
||||
generate_kwargs={"temperature": 0.7, "do_sample": False},
|
||||
model_kwargs={},
|
||||
messages_to_prompt=messages_to_prompt,
|
||||
|
|
@ -242,6 +244,8 @@ if __name__ == "__main__":
|
|||
help="the password of the user in the database")
|
||||
parser.add_argument('-e','--embedding-model-path',default="BAAI/bge-small-en",
|
||||
help="the path to embedding model path")
|
||||
parser.add_argument('-n','--n-predict', type=int, default=32,
|
||||
help='max number of predict tokens')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
|
@ -154,6 +154,7 @@ python rag.py -m <path_to_model>
|
|||
- `-p PASSWORD`: password in the PostgreSQL database
|
||||
- `-q QUESTION`: question you want to ask
|
||||
- `-d DATA`: path to source data used for retrieval (in pdf format)
|
||||
- `-n N_PREDICT`: max predict tokens
|
||||
|
||||
### 5. Example Output
|
||||
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ def main(args):
|
|||
model_name=args.model_path,
|
||||
tokenizer_name=args.model_path,
|
||||
context_window=512,
|
||||
max_new_tokens=32,
|
||||
max_new_tokens=args.n_predict,
|
||||
generate_kwargs={"temperature": 0.7, "do_sample": False},
|
||||
model_kwargs={},
|
||||
messages_to_prompt=messages_to_prompt,
|
||||
|
|
@ -243,6 +243,8 @@ if __name__ == "__main__":
|
|||
help="the password of the user in the database")
|
||||
parser.add_argument('-e','--embedding-model-path',default="BAAI/bge-small-en",
|
||||
help="the path to embedding model path")
|
||||
parser.add_argument('-n','--n-predict', type=int, default=32,
|
||||
help='max number of predict tokens')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
|
@ -235,9 +235,18 @@ class BigdlLLM(CustomLLM):
|
|||
"""
|
||||
model_kwargs = model_kwargs or {}
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
self._model = model or AutoModelForCausalLM.from_pretrained(
|
||||
model_name, load_in_4bit=True, **model_kwargs
|
||||
)
|
||||
if model:
|
||||
self._model = model
|
||||
else:
|
||||
try:
|
||||
self._model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, load_in_4bit=True, use_cache=True,
|
||||
trust_remote_code=True, **model_kwargs
|
||||
)
|
||||
except:
|
||||
from bigdl.llm.transformers import AutoModel
|
||||
self._model = AutoModel.from_pretrained(model_name,
|
||||
load_in_4bit=True, **model_kwargs)
|
||||
|
||||
if 'xpu' in device_map:
|
||||
self._model = self._model.to(device_map)
|
||||
|
|
@ -259,9 +268,15 @@ class BigdlLLM(CustomLLM):
|
|||
if "max_length" not in tokenizer_kwargs:
|
||||
tokenizer_kwargs["max_length"] = context_window
|
||||
|
||||
self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(
|
||||
tokenizer_name, **tokenizer_kwargs
|
||||
)
|
||||
if tokenizer:
|
||||
self._tokenizer = tokenizer
|
||||
else:
|
||||
print(f"load tokenizer: {tokenizer_name}")
|
||||
try:
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)
|
||||
except:
|
||||
self._tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name,
|
||||
trust_remote_code=True)
|
||||
|
||||
if tokenizer_name != model_name:
|
||||
logger.warning(
|
||||
|
|
|
|||
Loading…
Reference in a new issue