Fix llamaindex AutoTokenizer bug (#10345)

* fix tokenizer
* fix AutoTokenizer bug
* modify code style
parent 2a10b53d73
commit 9026c08633

5 changed files with 31 additions and 8 deletions
@@ -64,6 +64,7 @@ python rag.py -m <path_to_model>
 - `-p PASSWORD`: password in the PostgreSQL database
 - `-q QUESTION`: question you want to ask
 - `-d DATA`: path to source data used for retrieval (in pdf format)
+- `-n N_PREDICT`: max predict tokens
 
 ### Example Output
 
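With this change the RAG example can cap generation length from the command line. A hypothetical invocation (the paths are placeholders and the question string is made up for illustration):

    python rag.py -m <path_to_model> -d <path_to_pdf> -q "What is BigDL?" -n 64

Omitting `-n` keeps the previous behavior, since the flag defaults to 32.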
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 
+
+import torch
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from sqlalchemy import make_url
 from llama_index.vector_stores.postgres import PGVectorStore
@@ -167,7 +169,7 @@ def main(args):
         model_name=args.model_path,
         tokenizer_name=args.model_path,
         context_window=512,
-        max_new_tokens=32,
+        max_new_tokens=args.n_predict,
         generate_kwargs={"temperature": 0.7, "do_sample": False},
         model_kwargs={},
         messages_to_prompt=messages_to_prompt,
@@ -242,6 +244,8 @@ if __name__ == "__main__":
                         help="the password of the user in the database")
     parser.add_argument('-e','--embedding-model-path',default="BAAI/bge-small-en",
                         help="the path to embedding model path")
+    parser.add_argument('-n','--n-predict', type=int, default=32,
+                        help='max number of predict tokens')
     args = parser.parse_args()
 
     main(args)
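Taken together, the two rag.py hunks thread the new flag through to generation. A minimal runnable sketch of that wiring; the parser description is assumed, and the `BigdlLLM` constructor is left as a comment because its surrounding code is outside this diff:

    import argparse

    # New option exactly as added by the patch.
    parser = argparse.ArgumentParser(description='LlamaIndex RAG example')
    parser.add_argument('-n', '--n-predict', type=int, default=32,
                        help='max number of predict tokens')
    args = parser.parse_args()

    # The parsed value replaces the hard-coded 32 in the LLM constructor:
    #     max_new_tokens=args.n_predict,
    print(f"will generate at most {args.n_predict} tokens")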
@@ -154,6 +154,7 @@ python rag.py -m <path_to_model>
 - `-p PASSWORD`: password in the PostgreSQL database
 - `-q QUESTION`: question you want to ask
 - `-d DATA`: path to source data used for retrieval (in pdf format)
+- `-n N_PREDICT`: max predict tokens
 
 ### 5. Example Output
 
@@ -168,7 +168,7 @@ def main(args):
         model_name=args.model_path,
         tokenizer_name=args.model_path,
         context_window=512,
-        max_new_tokens=32,
+        max_new_tokens=args.n_predict,
         generate_kwargs={"temperature": 0.7, "do_sample": False},
         model_kwargs={},
         messages_to_prompt=messages_to_prompt,
@@ -243,6 +243,8 @@ if __name__ == "__main__":
                         help="the password of the user in the database")
     parser.add_argument('-e','--embedding-model-path',default="BAAI/bge-small-en",
                         help="the path to embedding model path")
+    parser.add_argument('-n','--n-predict', type=int, default=32,
+                        help='max number of predict tokens')
     args = parser.parse_args()
 
     main(args)
@@ -235,9 +235,18 @@ class BigdlLLM(CustomLLM):
         """
         model_kwargs = model_kwargs or {}
         from bigdl.llm.transformers import AutoModelForCausalLM
-        self._model = model or AutoModelForCausalLM.from_pretrained(
-            model_name, load_in_4bit=True, **model_kwargs
-        )
+        if model:
+            self._model = model
+        else:
+            try:
+                self._model = AutoModelForCausalLM.from_pretrained(
+                    model_name, load_in_4bit=True, use_cache=True,
+                    trust_remote_code=True, **model_kwargs
+                )
+            except:
+                from bigdl.llm.transformers import AutoModel
+                self._model = AutoModel.from_pretrained(model_name,
+                                                        load_in_4bit=True, **model_kwargs)
 
         if 'xpu' in device_map:
             self._model = self._model.to(device_map)
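The new loading logic can be exercised standalone. A minimal sketch, assuming a local checkpoint path (placeholder below) and using `except Exception:` where the patch uses a bare `except:`:

    from bigdl.llm.transformers import AutoModelForCausalLM

    model_path = "<path_to_model>"  # placeholder, not from the diff
    try:
        # Preferred path: causal-LM loader with 4-bit weights and KV cache.
        model = AutoModelForCausalLM.from_pretrained(
            model_path, load_in_4bit=True, use_cache=True, trust_remote_code=True
        )
    except Exception:
        # Fallback for checkpoints that only load through AutoModel
        # (presumably ChatGLM-style models).
        from bigdl.llm.transformers import AutoModel
        model = AutoModel.from_pretrained(model_path, load_in_4bit=True)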
@@ -259,9 +268,15 @@ class BigdlLLM(CustomLLM):
         if "max_length" not in tokenizer_kwargs:
             tokenizer_kwargs["max_length"] = context_window
 
-        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(
-            tokenizer_name, **tokenizer_kwargs
-        )
+        if tokenizer:
+            self._tokenizer = tokenizer
+        else:
+            print(f"load tokenizer: {tokenizer_name}")
+            try:
+                self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)
+            except:
+                self._tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name,
+                                                                 trust_remote_code=True)
 
         if tokenizer_name != model_name:
             logger.warning(
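The tokenizer side of the fix follows the same shape. A standalone sketch; the hunk does not show the file's imports, so pulling both classes from `transformers` is an assumption:

    from transformers import AutoTokenizer, LlamaTokenizer

    tokenizer_path = "<path_to_model>"  # placeholder, not from the diff
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    except Exception:
        # Some Llama checkpoints trip up AutoTokenizer resolution;
        # fall back to LlamaTokenizer as the patch does.
        tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)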