diff --git a/python/llm/README.md b/python/llm/README.md
index 1b3c0dc4..f9adec8f 100644
--- a/python/llm/README.md
+++ b/python/llm/README.md
@@ -75,40 +75,46 @@ llm-cli -x llama -h
 ```
 
 #### Transformers like API
-Users could load converted model or even the unconverted huggingface model directly by `AutoModelForCausalLM.from_pretrained`.
+You can also load the converted model using `BigdlForCausalLM` with a transformers-like API,
+```python
+from bigdl.llm.transformers import BigdlForCausalLM
+llm = BigdlForCausalLM.from_pretrained("/path/to/llama-7b-int4/bigdl-llm-xxx.bin",
+                                       model_family="llama")
+prompt="What is AI?"
+```
+and simply run inference end-to-end like
+```python
+output = llm(prompt, max_tokens=32)
+```
+If you need to separate tokenization and generation, you can also run inference like
+```python
+tokens_id = llm.tokenize(prompt)
+output_tokens_id = llm.generate(tokens_id, max_new_tokens=32)
+output = llm.batch_decode(output_tokens_id)
+```
+
+
+Alternatively, you can load a Hugging Face model directly using `AutoModelForCausalLM.from_pretrained`.
 ```python
-from bigdl.llm.ggml.transformers import AutoModelForCausalLM
+from bigdl.llm.transformers import AutoModelForCausalLM
 
-# option 1: load converted model
-llm = AutoModelForCausalLM.from_pretrained("/path/to/llama-7b-int4/bigdl-llm-xxx.bin",
-                                           model_family="llama")
-
-# option 2: load huggingface checkpoint
+# option 1: load huggingface checkpoint
 llm = AutoModelForCausalLM.from_pretrained("/path/to/llama-7b-hf/",
                                            model_family="llama")
 
-# option 3: load from huggingface hub repo
+# option 2: load from huggingface hub repo
 llm = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf",
                                            model_family="llama")
 ```
 
-Users could use llm to do the inference. Apart from end-to-end fast forward, we also support split the tokenization and model inference in our API.
-
+You can then use the model the same way as you would with transformers.
 ```python
-# end-to-end fast forward w/o spliting the tokenization and model inferencing
-result = llm("what is ai")
-
 # Use transformers tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
 tokens = tokenizer("what is ai").input_ids
 tokens_id = llm.generate(tokens, max_new_tokens=32)
 tokenizer.batch_decode(tokens_id)
-
-# Use bigdl-llm tokenizer
-tokens = llm.tokenize("what is ai")
-tokens_id = llm.generate(tokens, max_new_tokens=32)
-decoded = llm.batch_decode(tokens_id)
 ```
 
 #### llama-cpp-python like API
diff --git a/python/llm/example/transformers/int4_pipeline.py b/python/llm/example/transformers/int4_pipeline.py
index d8ae9238..7499a725 100644
--- a/python/llm/example/transformers/int4_pipeline.py
+++ b/python/llm/example/transformers/int4_pipeline.py
@@ -18,54 +18,49 @@ import time
 import argparse
 
 
-def convert_and_load(repo_id_or_model_path, model_family, n_threads):
+def convert(repo_id_or_model_path, model_family, tmp_path):
+    from bigdl.llm import llm_convert
+    original_llm_path = repo_id_or_model_path
+    bigdl_llm_path = llm_convert(
+        model=original_llm_path,
+        outfile='./',
+        outtype='int4',
+        tmp_path=tmp_path,
+        model_family=model_family)
 
-    from bigdl.llm.ggml.transformers import AutoModelForCausalLM
+    return bigdl_llm_path
 
-    # here you may input the HuggingFace repo id directly as the value of `pretrained_model_name_or_path`.
-    # This will allow the pre-trained model to be downloaded directly from the HuggingFace repository.
-    # The downloaded model will then be converted to binary format with int4 dtype weights,
-    # and saved into the cache_dir folder.
-    #
-    # if you already have the pre-trained model downloaded, you can provide the path to
-    # the downloaded folder as the value of `pretrained_model_name_or_path``
-    llm = AutoModelForCausalLM.from_pretrained(
-        pretrained_model_name_or_path=repo_id_or_model_path,
+def load(model_path, model_family, n_threads):
+    from bigdl.llm.transformers import BigdlForCausalLM
+    llm = BigdlForCausalLM.from_pretrained(
+        pretrained_model_name_or_path=model_path,
         model_family=model_family,
-        dtype='int4',
-        cache_dir='./',
         n_threads=n_threads)
 
-    # if you want to explicitly convert the pre-trained model, you can use the `llm_convert` API
-    # to convert the downloaded Huggungface checkpoint first,
-    # and then load the binary checkpoint directly.
-    #
-    # from bigdl.llm import llm_convert
-    #
-    # model_path = repo_id_or_model_path
-    # output_ckpt_path = llm_convert(
-    #     model=model_path,
-    #     outfile='./',
-    #     outtype='int4',
-    #     model_family=model_family)
-    #
-    # llm = AutoModelForCausalLM.from_pretrained(
-    #     pretrained_model_name_or_path=output_ckpt_path,
-    #     model_family=model_family,
-    #     n_threads=n_threads)
-
     return llm
 
 
 def inference(llm, repo_id_or_model_path, model_family, prompt):
     if model_family in ['llama', 'gptneox']:
-        # Option 1: Use HuggingFace transformers tokenizer
+        # ------ Option 1: Use bigdl-llm based tokenizer
+        print('-'*20, ' bigdl-llm based tokenizer ', '-'*20)
+        st = time.time()
+
+        # please note that the prompt here can either be a string or a list of strings
+        tokens_id = llm.tokenize(prompt)
+        output_tokens_id = llm.generate(tokens_id, max_new_tokens=32)
+        output = llm.batch_decode(output_tokens_id)
+
+        print(f'Inference time: {time.time()-st} s')
+        print(f'Output:\n{output}')
+
+        # ------ Option 2: Use HuggingFace transformers tokenizer
         print('-'*20, ' HuggingFace transformers tokenizer ', '-'*20)
         print('Please note that the loading of HuggingFace transformers tokenizer may take some time.\n')
 
         # here is only a workaround for default example model 'decapoda-research/llama-7b-hf' in LLaMA family,
         # due to its out-of-date 'tokenizer_class' defined in its tokenizer_config.json.
-        #
+        # in most cases, you can use `AutoTokenizer` instead.
         if model_family == 'llama':
             from transformers import LlamaTokenizer
@@ -84,17 +79,6 @@ def inference(llm, repo_id_or_model_path, model_family, prompt):
         print(f'Inference time: {time.time()-st} s')
         print(f'Output:\n{output}')
 
-    # Option 2: Use bigdl-llm based tokenizer
-    print('-'*20, ' bigdl-llm based tokenizer ', '-'*20)
-    st = time.time()
-
-    # please note that the prompt here can either be a string or a list of string
-    tokens_id = llm.tokenize(prompt)
-    output_tokens_id = llm.generate(tokens_id, max_new_tokens=32)
-    output = llm.batch_decode(output_tokens_id)
-
-    print(f'Inference time: {time.time()-st} s')
-    print(f'Output:\n{output}')
 
     if model_family in ['llama', 'gptneox', 'bloom']:
         # Option 3: fast forward
@@ -121,6 +105,8 @@ def main():
                         ', or the path to the huggingface checkpoint folder')
     parser.add_argument('--prompt', type=str, default='Q: What is CPU? 
A:', help='Prompt to infer') + parser.add_argument('--tmp-path', type=str, default='/tmp', + help='path to store intermediate model during the conversion process') args = parser.parse_args() repo_id_or_model_path = args.repo_id_or_model_path @@ -132,12 +118,18 @@ def main(): elif args.model_family == 'bloom': repo_id_or_model_path = 'bigscience/bloomz-7b1' - # Step 1: convert and load int4 model - llm = convert_and_load(repo_id_or_model_path=repo_id_or_model_path, - model_family=args.model_family, - n_threads=args.thread_num) + # Step 1: convert original model to BigDL llm model + bigdl_llm_path = convert(repo_id_or_model_path=repo_id_or_model_path, + model_family=args.model_family, + tmp_path=args.tmp_path) + + + # Step 2: load int4 model + llm = load(model_path=bigdl_llm_path, + model_family=args.model_family, + n_threads=args.thread_num) - # Step 2: conduct inference + # Step 3: inference inference(llm=llm, repo_id_or_model_path=repo_id_or_model_path, model_family=args.model_family, diff --git a/python/llm/src/bigdl/llm/ggml/transformers/__init__.py b/python/llm/src/bigdl/llm/ggml/transformers/__init__.py deleted file mode 100644 index 51a9e37d..00000000 --- a/python/llm/src/bigdl/llm/ggml/transformers/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# -# Copyright 2016 The BigDL Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# This would makes sure Python is aware there is more than one sub-package within bigdl, -# physically located elsewhere. -# Otherwise there would be module not found error in non-pip's setting as Python would -# only search the first bigdl package and end up finding only one sub-package. - -from .model import AutoModelForCausalLM diff --git a/python/llm/src/bigdl/llm/ggml/transformers/model.py b/python/llm/src/bigdl/llm/ggml/transformers/model.py deleted file mode 100644 index 86999bad..00000000 --- a/python/llm/src/bigdl/llm/ggml/transformers/model.py +++ /dev/null @@ -1,128 +0,0 @@ -# -# Copyright 2016 The BigDL Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# This would makes sure Python is aware there is more than one sub-package within bigdl, -# physically located elsewhere. -# Otherwise there would be module not found error in non-pip's setting as Python would -# only search the first bigdl package and end up finding only one sub-package. 
- -import os -import traceback -from bigdl.llm.utils.common import invalidInputError - - -class AutoModelForCausalLM: - """ - A generic model class that mimics the behavior of - ``transformers.AutoModelForCausalLM.from_pretrained`` API - """ - - @classmethod - def from_pretrained(cls, - pretrained_model_name_or_path: str, - model_format: str = 'pth', - model_family: str = 'llama', - dtype: str = 'int4', - cache_dir: str = './', - tmp_path: str = None, - **kwargs): - """ - :param pretrained_model_name_or_path: We support 3 kinds of pretrained model checkpoint - - 1. Path to directory for Hugging Face checkpoint that are directly pulled from - Hugging Face hub. - - If ``model_format='pth'``, the folder should contain: weight bin, tokenizer - config, tokenizer.model (required for llama) and added_tokens.json (if applied). - For lora fine tuned model, the path should be pointed to a merged weight. - - If ``model_format='gptq'``, the folder should be be a Hugging Face checkpoint - in GPTQ format, which contains weights in pytorch's .pt format, - and ``tokenizer.model``. - - 2. Path for converted BigDL-LLM optimized ggml binary checkpoint. - The checkpoint should be converted by ``bigdl.llm.llm_convert``. - 3. A str for Hugging Face hub repo id. - - :param model_format: Specify the model format to be converted. ``pth`` is for - PyTorch model checkpoint from Hugging Face. ``gptq`` is for GPTQ format - model from Hugging Face. - :param model_family: The model family of the pretrained checkpoint. - Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"`` and ``"starcoder"``. - :param dtype: Which quantized precision will be converted. - Now only `int4` and `int8` are supported, and `int8` only works for `llama` - , `gptneox` and `starcoder`. - :param cache_dir: (optional) This parameter will only be used when - ``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id. - It indicates the saving path for the converted low precision model. - :param tmp_path: (optional) Which path to store the intermediate fp16 model during the - conversion process. Default to `None` so that intermediate model will not be saved. - :param **kwargs: keyword arguments which will be passed to the model instance - - :return: a model instance - """ - invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder'], - "Now we only support model family: 'llama', 'gptneox', 'bloom'," - " 'starcoder', '{}' is not in the list.".format(model_family)) - invalidInputError(dtype.lower() in ['int4', 'int8'], - "Now we only support int4 and int8 as date type for weight") - - # check whether pretrained_model_name_or_path exists. - # if not, it is likely that the user wants to pass in the repo id. - if not os.path.exists(pretrained_model_name_or_path): - try: - # download from Hugging Face based on repo id - from huggingface_hub import snapshot_download - pretrained_model_name_or_path = snapshot_download( - repo_id=pretrained_model_name_or_path) - except Exception as e: - traceback.print_exc() - # if downloading fails, it could be the case that repo id is invalid, - # or the user pass in the wrong path for checkpoint - invalidInputError(False, - "Downloadng from Hugging Face repo id {} failed. 
" - "Please input valid Hugging Face hub repo id, " - "or provide the valid path to Hugging Face / " - "BigDL-LLM optimized ggml binary checkpoint, " - "for pretrained_model_name_or_path" - .format(pretrained_model_name_or_path)) - - ggml_model_path = pretrained_model_name_or_path - # check whether pretrained_model_name_or_path is a file. - # if not, it is likely that pretrained_model_name_or_path - # points to a Hugging Face checkpoint - if not os.path.isfile(pretrained_model_name_or_path): - # Hugging Face checkpoint - from bigdl.llm import llm_convert - ggml_model_path = llm_convert(model=pretrained_model_name_or_path, - outfile=cache_dir, - model_family=model_family, - outtype=dtype, - model_format=model_format, - tmp_path=tmp_path) - - if model_family == 'llama': - from bigdl.llm.ggml.model.llama import Llama - return Llama(model_path=ggml_model_path, **kwargs) - elif model_family == 'gptneox': - from bigdl.llm.ggml.model.gptneox import Gptneox - return Gptneox(model_path=ggml_model_path, **kwargs) - elif model_family == 'bloom': - from bigdl.llm.ggml.model.bloom import Bloom - return Bloom(model_path=ggml_model_path, **kwargs) - elif model_family == 'starcoder': - from bigdl.llm.ggml.model.starcoder import Starcoder - return Starcoder(model_path=ggml_model_path, **kwargs) diff --git a/python/llm/src/bigdl/llm/transformers/__init__.py b/python/llm/src/bigdl/llm/transformers/__init__.py index c6713d32..eaef6320 100644 --- a/python/llm/src/bigdl/llm/transformers/__init__.py +++ b/python/llm/src/bigdl/llm/transformers/__init__.py @@ -16,3 +16,4 @@ from .convert import ggml_convert_int4 from .model import AutoModelForCausalLM, AutoModel +from .modelling_bigdl import BigdlForCausalLM diff --git a/python/llm/src/bigdl/llm/transformers/modelling_bigdl.py b/python/llm/src/bigdl/llm/transformers/modelling_bigdl.py new file mode 100644 index 00000000..2b067a3b --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/modelling_bigdl.py @@ -0,0 +1,73 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This would makes sure Python is aware there is more than one sub-package within bigdl, +# physically located elsewhere. +# Otherwise there would be module not found error in non-pip's setting as Python would +# only search the first bigdl package and end up finding only one sub-package. + +from bigdl.llm.utils.common import invalidInputError + + +class BigdlForCausalLM: + """ + A generic model class that mimics the behavior of + ``transformers.LlamaForCausalLM.from_pretrained`` API + """ + + @classmethod + def from_pretrained(cls, + pretrained_model_name_or_path: str, + model_family: str = 'llama', + dtype: str = 'int4', + **kwargs): + """ + :param pretrained_model_name_or_path: Path for converted BigDL-LLM optimized ggml + binary checkpoint. The checkpoint should be converted by ``bigdl.llm.llm_convert``. + :param model_family: The model family of the pretrained checkpoint. 
+               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"`` and ``"starcoder"``.
+        :param dtype: The quantized data type of the converted checkpoint.
+               Currently only `int4` and `int8` are supported, and `int8` only works
+               for `llama`, `gptneox` and `starcoder`.
+        :param **kwargs: keyword arguments which will be passed to the model instance
+
+        :return: a model instance
+        """
+        invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder'],
+                          "Now we only support model family: 'llama', 'gptneox', 'bloom',"
+                          " 'starcoder', '{}' is not in the list.".format(model_family))
+        invalidInputError(dtype.lower() in ['int4', 'int8'],
+                          "Now we only support int4 and int8 as data type for weight")
+
+        ggml_model_path = pretrained_model_name_or_path
+
+        if model_family == 'llama':
+            from bigdl.llm.ggml.model.llama import Llama
+            return Llama(model_path=ggml_model_path, **kwargs)
+        elif model_family == 'gptneox':
+            from bigdl.llm.ggml.model.gptneox import Gptneox
+            return Gptneox(model_path=ggml_model_path, **kwargs)
+        elif model_family == 'bloom':
+            from bigdl.llm.ggml.model.bloom import Bloom
+            return Bloom(model_path=ggml_model_path, **kwargs)
+        elif model_family == 'starcoder':
+            from bigdl.llm.ggml.model.starcoder import Starcoder
+            return Starcoder(model_path=ggml_model_path, **kwargs)
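
Taken together, the changes above replace the implicit convert-on-load behavior of the old `AutoModelForCausalLM` with an explicit convert-then-load flow. Below is a minimal end-to-end sketch of the new usage, based only on the APIs shown in this diff (`llm_convert`, `BigdlForCausalLM`, and the `tokenize`/`generate`/`batch_decode` helpers); the checkpoint path and `tmp_path` value are placeholders, not part of the change:

```python
from bigdl.llm import llm_convert
from bigdl.llm.transformers import BigdlForCausalLM

# Step 1: convert a downloaded Hugging Face checkpoint (placeholder path)
# into a BigDL-LLM int4 ggml binary; tmp_path holds the intermediate
# fp16 model produced during conversion.
bigdl_llm_path = llm_convert(model="/path/to/llama-7b-hf/",
                             outfile="./",
                             outtype="int4",
                             tmp_path="/tmp",
                             model_family="llama")

# Step 2: load the converted binary checkpoint.
llm = BigdlForCausalLM.from_pretrained(bigdl_llm_path, model_family="llama")

prompt = "What is AI?"

# End-to-end generation.
output = llm(prompt, max_tokens=32)

# Or split tokenization and generation explicitly.
tokens_id = llm.tokenize(prompt)
output_tokens_id = llm.generate(tokens_id, max_new_tokens=32)
output = llm.batch_decode(output_tokens_id)
print(output)
```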