diff --git a/python/llm/README.md b/python/llm/README.md
index b5b04c91..277cc37f 100644
--- a/python/llm/README.md
+++ b/python/llm/README.md
@@ -32,20 +32,30 @@ A standard procedure for using `bigdl-llm` contains 3 steps:
 3. Inference using `llm-cli`, transformers like API, or `langchain`.
 
 ### Convert your model
-A python function and a command line tool `convert_model` is provided to transform the model from huggingface format to GGML format.
+A Python function and a command line tool `llm-convert` are provided to transform the model from Hugging Face format to GGML format.
 
-Here is an example to use `convert_model` command line tool.
+Here is an example of using the `llm-convert` command line tool.
 ```bash
-convert_model -i "/path/to/llama-7b-hf/" -o "/path/to/llama-7b-int4/" -x "llama"
+# pth model
+llm-convert "/path/to/llama-7b-hf/" --model-format pth --outfile "/path/to/llama-7b-int4/" --model-family "llama"
+# gptq model
+llm-convert "/path/to/vicuna-13B-1.1-GPTQ-4bit-128g.pt" --model-format gptq --outfile "/path/to/out.bin" --tokenizer-path "/path/to/tokenizer.model" --model-family "llama"
 ```
 
-Here is an example to use `convert_model` python API.
-```bash
-from bigdl.llm.ggml import convert_model
-
-convert_model(input_path="/path/to/llama-7b-hf/",
-              output_path="/path/to/llama-7b-int4/",
-              model_family="llama")
+Here is an example of using the `llm_convert` Python API.
+```python
+from bigdl.llm import llm_convert
+# pth model
+llm_convert(model="/path/to/llama-7b-hf/",
+            outfile="/path/to/llama-7b-int4/",
+            model_format="pth",
+            model_family="llama")
+# gptq model
+llm_convert(model="/path/to/vicuna-13B-1.1-GPTQ-4bit-128g.pt",
+            outfile="/path/to/out.bin",
+            model_format="gptq",
+            tokenizer_path="/path/to/tokenizer.model",
+            model_family="llama")
 ```
 
 ### Inferencing
diff --git a/python/llm/example/transformers/int4_pipeline.py b/python/llm/example/transformers/int4_pipeline.py
index 9adfa5d1..05b30ff5 100644
--- a/python/llm/example/transformers/int4_pipeline.py
+++ b/python/llm/example/transformers/int4_pipeline.py
@@ -36,14 +36,14 @@ def convert_and_load(repo_id_or_model_path, model_family, n_threads):
                                          cache_dir='./',
                                          n_threads=n_threads)
 
-    # if you want to explicitly convert the pre-trained model, you can use the `convert_model` API
-    # to convert the downloaded Huggungface checkpoint first,
+    # if you want to explicitly convert the pre-trained model, you can use the `llm_convert` API
+    # to convert the downloaded Hugging Face checkpoint first,
     # and then load the binary checkpoint directly.
     #
-    # from bigdl.llm.ggml import convert_model
+    # from bigdl.llm import llm_convert
     #
     # model_path = repo_id_or_model_path
-    # output_ckpt_path = convert_model(
-    #     input_path=model_path,
-    #     output_path='./',
-    #     dtype='int4',
+    # output_ckpt_path = llm_convert(
+    #     model=model_path,
+    #     outfile='./',
+    #     outtype='int4',
diff --git a/python/llm/setup.py b/python/llm/setup.py
index 59066f11..e0aa06e6 100644
--- a/python/llm/setup.py
+++ b/python/llm/setup.py
@@ -207,7 +207,7 @@ def setup_package():
         include_package_data=True,
         entry_points={
             "console_scripts": [
-                'convert_model=bigdl.llm.ggml.convert_model:main'
+                'llm-convert=bigdl.llm.convert_model:main'
             ]
         },
         extras_require={"all": all_requires},
diff --git a/python/llm/src/bigdl/llm/__init__.py b/python/llm/src/bigdl/llm/__init__.py
index dbdafd2a..a9c71a16 100644
--- a/python/llm/src/bigdl/llm/__init__.py
+++ b/python/llm/src/bigdl/llm/__init__.py
@@ -18,3 +18,5 @@
 # physically located elsewhere.
 # Otherwise there would be module not found error in non-pip's setting as Python would
 # only search the first bigdl package and end up finding only one sub-package.
+
+from .convert_model import llm_convert
diff --git a/python/llm/src/bigdl/llm/convert_model.py b/python/llm/src/bigdl/llm/convert_model.py
new file mode 100644
index 00000000..2da473ed
--- /dev/null
+++ b/python/llm/src/bigdl/llm/convert_model.py
@@ -0,0 +1,102 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from bigdl.llm.ggml.convert_model import convert_model as ggml_convert_model
+from bigdl.llm.gptq.convert.convert_gptq_to_ggml import convert_gptq2ggml
+from bigdl.llm.utils.common import invalidInputError
+import argparse
+
+
+def _special_kwarg_check(kwargs, check_args):
+    # Accept only kwargs whose names appear in check_args; report the
+    # first unknown kwarg as a {name: value} dict otherwise.
+    _used_args = {}
+    for arg in kwargs:
+        if arg not in check_args:
+            return False, {arg: kwargs[arg]}
+        _used_args[arg] = kwargs[arg]
+    return True, _used_args
+
+
+def llm_convert(model,
+                outfile,
+                model_family,
+                outtype='int4',
+                model_format="pth",
+                **kwargs):
+    if model_format == "pth":
+        check, _used_args = _special_kwarg_check(kwargs=kwargs,
+                                                 check_args=["tmp_path"])
+        invalidInputError(check, f"Invalid input kwargs found: {_used_args}")
+        # Return the path of the converted ggml checkpoint.
+        return ggml_convert_model(input_path=model,
+                                  output_path=outfile,
+                                  model_family=model_family,
+                                  dtype=outtype,
+                                  **_used_args,
+                                  )
+    elif model_format == "gptq":
+        invalidInputError(model.endswith(".pt"),
+                          "For GPTQ models, only PyTorch's .pt format is supported for now.")
+        invalidInputError(model_family == "llama" and outtype == 'int4',
+                          "Converting GPTQ models currently requires "
+                          "`--model-family llama --outtype int4` in the command line.")
+        check, _used_args = _special_kwarg_check(kwargs=kwargs,
+                                                 check_args=["tokenizer_path"])
+        invalidInputError(check, f"Invalid input kwargs found: {_used_args}")
+        invalidInputError("tokenizer_path" in _used_args,
+                          "The GPTQ model requires the `tokenizer_path` parameter to be provided. "
+                          "Usage: llm-convert --model-format gptq "
+                          "--model-family llama --tokenizer-path tokenizer.model "
+                          "--outfile out.bin llamaXXb-4bit.pt")
+        convert_gptq2ggml(input_path=model,
+                          tokenizer_path=_used_args["tokenizer_path"],
+                          output_path=outfile)
+        return outfile
+    else:
+        invalidInputError(False, f"Unsupported model format: {model_format}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Model Convert Parameters')
+    parser.add_argument('model', type=str,
+                        help=("path of the model to be converted: a *directory* of "
+                              "Hugging Face checkpoints for `pth`, or a `.pt` file for `gptq`"))
+    parser.add_argument('-o', '--outfile', type=str, required=True,
+                        help="save path of the output quantized model")
+    parser.add_argument('-x', '--model-family', type=str, required=True,
+                        help=("which model family the input model belongs to; "
+                              "now only `llama`/`bloom`/`gptneox` are supported"))
+    parser.add_argument('-f', '--model-format', type=str, required=True,
+                        help=("format of the input model to be converted to a GGML "
+ "Now only `pth`/`gptq` are supported.")) + parser.add_argument('-t', '--outtype', type=str, default="int4", + help="Which quantized precision will be converted.") + + # pth specific args + parser.add_argument('-p', '--tmp-path', type=str, default=None, + help="Which path to store the intermediate model during the" + "conversion process.") + + # gptq specific args + parser.add_argument('-k', '--tokenizer-path', type=str, default=None, + help="tokenizer_path, a path of tokenizer.model") + args = parser.parse_args() + params = vars(args) + llm_convert(**params) diff --git a/python/llm/src/bigdl/llm/gptq/__init__.py b/python/llm/src/bigdl/llm/gptq/__init__.py new file mode 100644 index 00000000..dbdafd2a --- /dev/null +++ b/python/llm/src/bigdl/llm/gptq/__init__.py @@ -0,0 +1,20 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This would makes sure Python is aware there is more than one sub-package within bigdl, +# physically located elsewhere. +# Otherwise there would be module not found error in non-pip's setting as Python would +# only search the first bigdl package and end up finding only one sub-package. diff --git a/python/llm/src/bigdl/llm/gptq/convert/__init__.py b/python/llm/src/bigdl/llm/gptq/convert/__init__.py new file mode 100644 index 00000000..dbdafd2a --- /dev/null +++ b/python/llm/src/bigdl/llm/gptq/convert/__init__.py @@ -0,0 +1,20 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This would makes sure Python is aware there is more than one sub-package within bigdl, +# physically located elsewhere. +# Otherwise there would be module not found error in non-pip's setting as Python would +# only search the first bigdl package and end up finding only one sub-package. 
diff --git a/python/llm/src/bigdl/llm/gptq/convert/convert-gptq-to-ggml.py b/python/llm/src/bigdl/llm/gptq/convert/convert_gptq_to_ggml.py
similarity index 98%
rename from python/llm/src/bigdl/llm/gptq/convert/convert-gptq-to-ggml.py
rename to python/llm/src/bigdl/llm/gptq/convert/convert_gptq_to_ggml.py
index 82f6b084..ec8df047 100644
--- a/python/llm/src/bigdl/llm/gptq/convert/convert-gptq-to-ggml.py
+++ b/python/llm/src/bigdl/llm/gptq/convert/convert_gptq_to_ggml.py
@@ -155,8 +155,8 @@ def convert_q4(src_name, dst_name, model, fout, n_head, permute=False):
     blob.tofile(fout)
 
 
-def convert_gptq2ggml(model_path, tokenizer_path, output_path):
-    model = torch.load(model_path, map_location="cpu")
+def convert_gptq2ggml(input_path, tokenizer_path, output_path):
+    model = torch.load(input_path, map_location="cpu")
 
     n_vocab, n_embd = model['model.embed_tokens.weight'].shape
     layer_re = r'model\.layers\.([0-9]+)'
diff --git a/python/llm/test/convert/test_convert_model.py b/python/llm/test/convert/test_convert_model.py
index 125ecf9c..6a39801c 100644
--- a/python/llm/test/convert/test_convert_model.py
+++ b/python/llm/test/convert/test_convert_model.py
@@ -19,7 +19,7 @@
 import pytest
 import os
 from unittest import TestCase
 
-from bigdl.llm.ggml import convert_model
+from bigdl.llm import llm_convert
 
 llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH')
@@ -30,24 +30,27 @@ output_dir = os.environ.get('INT4_CKPT_DIR')
 
 class TestConvertModel(TestCase):
     def test_convert_llama(self):
-        converted_model_path = convert_model(input_path=llama_model_path,
-                                             output_path=output_dir,
-                                             model_family='llama',
-                                             dtype='int4')
+        converted_model_path = llm_convert(model=llama_model_path,
+                                           outfile=output_dir,
+                                           model_family='llama',
+                                           model_format="pth",
+                                           outtype='int4')
         assert os.path.isfile(converted_model_path)
 
     def test_convert_gptneox(self):
-        converted_model_path = convert_model(input_path=gptneox_model_path,
-                                             output_path=output_dir,
-                                             model_family='gptneox',
-                                             dtype='int4')
+        converted_model_path = llm_convert(model=gptneox_model_path,
+                                           outfile=output_dir,
+                                           model_family='gptneox',
+                                           model_format="pth",
+                                           outtype='int4')
         assert os.path.isfile(converted_model_path)
 
     def test_convert_bloom(self):
-        converted_model_path = convert_model(input_path=bloom_model_path,
-                                             output_path=output_dir,
-                                             model_family='bloom',
-                                             dtype='int4')
+        converted_model_path = llm_convert(model=bloom_model_path,
+                                           outfile=output_dir,
+                                           model_family='bloom',
+                                           model_format="pth",
+                                           outtype='int4')
         assert os.path.isfile(converted_model_path)
 
 if __name__ == '__main__':
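(Illustrative, not part of the patch: the updated tests above only exercise the `pth` path. A GPTQ conversion driven from the Python API would look like the sketch below, with placeholder paths mirroring the README example; per the converter code, `llm_convert` returns the path of the written GGML checkpoint.)

```python
from bigdl.llm import llm_convert

# GPTQ -> GGML: needs a llama-family GPTQ .pt checkpoint and the
# matching tokenizer.model (paths below are placeholders).
output_path = llm_convert(model="/path/to/vicuna-13B-1.1-GPTQ-4bit-128g.pt",
                          outfile="/path/to/out.bin",
                          model_format="gptq",
                          tokenizer_path="/path/to/tokenizer.model",
                          model_family="llama")
```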