[LLM] Supports GPTQ convert in transformers-like API, and supports folder outfile for llm-convert (#8366)

* Add docstrings to llm_convert

* Small docstrings fix

* Unify outfile type to be a folder path for either gptq or pth model_format

* Supports gptq model input for from_pretrained

* Fix example and readme

* Small fix

* Python style fix

* Bug fix in llm_convert

* Python style check

* Fix based on comments

* Small fix
Yuwen Hu 2023-06-20 17:42:38 +08:00 committed by GitHub
parent 4ec46afa4f
commit 7ef1c890eb
6 changed files with 102 additions and 34 deletions


@@ -39,8 +39,9 @@ Here is an example to use `llm-convert` command line tool.
# pth model
llm-convert "/path/to/llama-7b-hf/" --model-format pth --outfile "/path/to/llama-7b-int4/" --model-family "llama"
# gptq model
llm-convert "/path/to/vicuna-13B-1.1-GPTQ-4bit-128g/" --model-format gptq --outfile "/path/to/out.bin" --model-family "llama"
llm-convert "/path/to/vicuna-13B-1.1-GPTQ-4bit-128g/" --model-format gptq --outfile "/path/to/vicuna-13B-int4/" --model-family "llama"
```
> An example GPTQ model can be found [here](https://huggingface.co/TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g/tree/main)
Here is an example of how to use the `llm_convert` Python API.
```python
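
A hedged sketch of such a Python call, using the argument names from the `llm_convert` docstring added in this commit (paths are placeholders):

```python
from bigdl.llm import llm_convert

# Convert a Hugging Face GPTQ checkpoint into a BigDL-LLM optimized GGML binary.
# The outfile argument must be an existing directory.
output_ckpt_path = llm_convert(
    model="/path/to/vicuna-13B-1.1-GPTQ-4bit-128g/",
    outfile="/path/to/vicuna-13B-int4/",
    outtype="int4",
    model_family="llama",
    model_format="gptq",
)
print(output_ckpt_path)  # path to the converted .bin checkpoint inside outfile
```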


@@ -40,13 +40,13 @@ def convert_and_load(repo_id_or_model_path, model_family, n_threads):
# to convert the downloaded Huggingface checkpoint first,
# and then load the binary checkpoint directly.
#
# from bigdl.llm.ggml import llm_convert
# from bigdl.llm import llm_convert
#
# model_path = repo_id_or_model_path
# output_ckpt_path = llm_convert(
# input_path=model_path,
# output_path='./',
# dtype='int4',
# model=model_path,
# outfile='./',
# outtype='int4',
# model_family=model_family)
#
# llm = AutoModelForCausalLM.from_pretrained(
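
A runnable version of the commented-out convert-then-load flow above might look like the following sketch; the `AutoModelForCausalLM` import path is assumed (it is not shown in this hunk), and paths are placeholders:

```python
from bigdl.llm import llm_convert
# Import path assumed; use whatever the surrounding example script imports.
from bigdl.llm.ggml.transformers import AutoModelForCausalLM

model_path = "/path/to/llama-7b-hf/"  # placeholder Hugging Face checkpoint directory

# Convert the checkpoint to a BigDL-LLM optimized GGML binary in the current directory.
output_ckpt_path = llm_convert(
    model=model_path,
    outfile="./",
    outtype="int4",
    model_family="llama",
)

# Load the converted binary checkpoint directly.
llm = AutoModelForCausalLM.from_pretrained(
    output_ckpt_path,
    model_family="llama",
)
```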


@@ -19,6 +19,7 @@ from bigdl.llm.ggml.convert_model import convert_model as ggml_convert_model
from bigdl.llm.gptq.convert.convert_gptq_to_ggml import convert_gptq2ggml
from bigdl.llm.utils.common import invalidInputError
import argparse
import os
def _special_kwarg_check(kwargs, check_args):
@@ -35,6 +36,46 @@ def llm_convert(model,
outtype='int4',
model_format="pth",
**kwargs):
"""
This function is able to:
1. Convert Hugging Face llama-like / gpt-neox-like / bloom-like / starcoder-like
PyTorch model to lower precision in BigDL-LLM optimized GGML format.
2. Convert Hugging Face GPTQ format llama-like model to BigDL-LLM optimized
GGML format.
:param model: Path to a **directory**:
1. If ``model_format='pth'``, the folder should be a Hugging Face checkpoint
that is directly pulled from Hugging Face hub, for example ``./llama-7b-hf``.
This should be a dir path that contains: weight bin, tokenizer config,
tokenizer.model (required for llama) and added_tokens.json (if applied).
For a LoRA fine-tuned model, the path should point to the merged weights.
2. If ``model_format='gptq'``, the folder should be a Hugging Face checkpoint
in GPTQ format, which contains weights in PyTorch's .pt format,
and ``tokenizer.model``.
:param outfile: Save path of output quantized model. You must pass a **directory** to
save all related output.
:param model_family: Which model family your input model belongs to.
Currently ``llama``/``bloom``/``gptneox``/``starcoder`` are supported.
If ``model_format='gptq'``, only ``llama`` is supported.
:param outtype: Which quantized precision will be converted.
If ``model_format='pth'``, ``int4`` and ``int8`` are supported,
and ``int8`` only works for ``llama`` and ``gptneox``.
If ``model_format='gptq'``, only ``int4`` is supported.
:param model_format: Specify the model format to be converted. ``pth`` is for
PyTorch model checkpoint from Hugging Face. ``gptq`` is for GPTQ format
model from Hugging Face.
:param **kwargs: Supported keyword arguments include:
* ``tmp_path``: Valid when ``model_format='pth'``. It refers to the path
that stores the intermediate model during the conversion process.
* ``tokenizer_path``: Valid when ``model_format='gptq'``. It refers to the path
where ``tokenizer.model`` is located (if it is not in the ``model`` directory)
:return: the path string to the converted lower precision checkpoint.
"""
if model_format == "pth":
_, _used_args = _special_kwarg_check(kwargs=kwargs,
check_args=["tmp_path"])
@@ -48,11 +89,23 @@ def llm_convert(model,
invalidInputError(model_family == "llama" and outtype == 'int4',
"Convert GPTQ models should always "
"specify `--model-family llama --dtype int4` in the command line.")
invalidInputError(os.path.isdir(outfile),
"The output_path {} is not a directory".format(outfile))
_, _used_args = _special_kwarg_check(kwargs=kwargs,
check_args=["tokenizer_path"])
output_filename = "bigdl_llm_{}_{}_from_gptq.bin".format(model_family,
outtype.lower())
outfile = os.path.join(outfile, output_filename)
if "tokenizer_path" in _used_args:
gptq_tokenizer_path = _used_args["tokenizer_path"]
else:
gptq_tokenizer_path = None
convert_gptq2ggml(input_path=model,
output_path=outfile,
tokenizer_path=_used_args["tokenizer_path"],
tokenizer_path=gptq_tokenizer_path,
)
return outfile
else:
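
For reference, a small self-contained sketch of the output-path composition performed in the GPTQ branch above (paths and values are placeholders; the file-name pattern mirrors the format string in the code):

```python
import os

outfile_dir = "/path/to/vicuna-13B-int4/"  # must already exist as a directory
model_family, outtype = "llama", "int4"

output_filename = "bigdl_llm_{}_{}_from_gptq.bin".format(model_family, outtype.lower())
print(os.path.join(outfile_dir, output_filename))
# -> /path/to/vicuna-13B-int4/bigdl_llm_llama_int4_from_gptq.bin
```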


@@ -29,17 +29,18 @@ def convert_model(input_path: str,
dtype: str = 'int4',
tmp_path: str = None):
"""
Convert Hugging Face llama-like / gpt-neox-like / bloom-like model to lower precision
Convert Hugging Face llama-like / gpt-neox-like / bloom-like / starcoder-like
PyTorch model to lower precision
:param input_path: Path to a *directory* for huggingface checkpoint that are directly
:param input_path: Path to a **directory** for huggingface checkpoint that is directly
pulled from huggingface hub, for example `./llama-7b-hf`. This should be a dir
path that contains: weight bin, tokenizer config, tokenizer.model (required for
llama) and added_tokens.json (if applied).
For lora finetuned model, the path should be pointed to a merged weight.
:param output_path: Save path of output quantized model. You must pass a *directory* to
:param output_path: Save path of output quantized model. You must pass a **directory** to
save all related output.
:param model_family: Which model family your input model belongs to.
Now only `llama`/`bloom`/`gptneox`/`starcoder` are supported.
Now only ``llama``/``bloom``/``gptneox``/``starcoder`` are supported.
:param dtype: Which quantized precision will be converted.
Now only `int4` and `int8` are supported, and `int8` only works for `llama`
and `gptneox`.
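
A minimal sketch of calling this lower-level `convert_model` API directly, based only on the parameters documented above (paths are placeholders):

```python
from bigdl.llm.ggml import convert_model

# Convert a Hugging Face checkpoint directory to an int4 GGML binary.
# output_path must be an existing directory.
output_ckpt_path = convert_model(
    input_path="/path/to/llama-7b-hf/",
    output_path="/path/to/llama-7b-int4/",
    model_family="llama",
    dtype="int4",
)
```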


@@ -74,9 +74,9 @@ def quantize(input_path: str, output_path: str,
family('llama', 'bloom', 'gptneox', 'starcoder')",
"{} is not in the list.".format(model_family))
invalidInputError(os.path.isfile(input_path),
"The file {} was not found".format(input_path))
"The file {} is not found".format(input_path))
invalidInputError(os.path.isdir(output_path),
"The output_path {} was not a directory".format(output_path))
"The output_path {} is not a directory".format(output_path))
# convert quantize type str into corresponding int value
quantize_type_map = _quantize_type[model_family]
output_filename = "bigdl_llm_{}_{}.bin".format(model_family,


@@ -33,6 +33,7 @@ class AutoModelForCausalLM:
@classmethod
def from_pretrained(cls,
pretrained_model_name_or_path: str,
model_format: str = 'pth',
model_family: str = 'llama',
dtype: str = 'int4',
cache_dir: str = './',
@@ -41,20 +42,30 @@ class AutoModelForCausalLM:
"""
:param pretrained_model_name_or_path: We support 3 kinds of pretrained model checkpoints
1. path for huggingface checkpoint that are directly pulled from huggingface hub.
This should be a dir path that contains: weight bin, tokenizer config,
tokenizer.model (required for llama) and added_tokens.json (if applied).
For lora fine tuned model, the path should be pointed to a merged weight.
2. path for converted ggml binary checkpoint. The checkpoint should be converted by
``bigdl.llm.ggml.convert_model``.
3. a str for huggingface hub repo id.
1. Path to a directory for a Hugging Face checkpoint that is directly pulled from
Hugging Face hub.
:param model_family: the model family of the pretrained checkpoint.
If ``model_format='pth'``, the folder should contain: weight bin, tokenizer
config, tokenizer.model (required for llama) and added_tokens.json (if applied).
For a LoRA fine-tuned model, the path should point to the merged weights.
If ``model_format='gptq'``, the folder should be a Hugging Face checkpoint
in GPTQ format, which contains weights in PyTorch's .pt format,
and ``tokenizer.model``.
2. Path for converted BigDL-LLM optimized ggml binary checkpoint.
The checkpoint should be converted by ``bigdl.llm.llm_convert``.
3. A str for Hugging Face hub repo id.
:param model_format: Specify the model format to be converted. ``pth`` is for
PyTorch model checkpoint from Hugging Face. ``gptq`` is for GPTQ format
model from Hugging Face.
:param model_family: The model family of the pretrained checkpoint.
Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"`` and ``"starcoder"``.
:param dtype: Which quantized precision will be converted.
Now only ``int4`` and ``int8`` are supported, and ``int8`` only works for
``llama``, ``gptneox`` and ``starcoder``.
:param cache_dir: (optional) this parameter will only be used when
:param cache_dir: (optional) This parameter will only be used when
``pretrained_model_name_or_path`` is a Hugging Face checkpoint or hub repo id.
It indicates the saving path for the converted low precision model.
:param tmp_path: (optional) Which path to store the intermediate fp16 model during the
@@ -73,7 +84,7 @@ class AutoModelForCausalLM:
# if not, it is likely that the user wants to pass in the repo id.
if not os.path.exists(pretrained_model_name_or_path):
try:
# download from huggingface based on repo id
# download from Hugging Face based on repo id
from huggingface_hub import snapshot_download
pretrained_model_name_or_path = snapshot_download(
repo_id=pretrained_model_name_or_path)
@@ -82,24 +93,26 @@ class AutoModelForCausalLM:
# if downloading fails, it could be the case that repo id is invalid,
# or the user pass in the wrong path for checkpoint
invalidInputError(False,
"Downloadng from huggingface repo id {} failed. "
"Please input valid huggingface hub repo id, "
"or provide the valid path to huggingface / "
"ggml binary checkpoint, for pretrained_model_name_or_path"
"Downloadng from Hugging Face repo id {} failed. "
"Please input valid Hugging Face hub repo id, "
"or provide the valid path to Hugging Face / "
"BigDL-LLM optimized ggml binary checkpoint, "
"for pretrained_model_name_or_path"
.format(pretrained_model_name_or_path))
ggml_model_path = pretrained_model_name_or_path
# check whether pretrained_model_name_or_path is a file.
# if not, it is likely that pretrained_model_name_or_path
# points to a huggingface checkpoint
# points to a Hugging Face checkpoint
if not os.path.isfile(pretrained_model_name_or_path):
# huggingface checkpoint
from bigdl.llm.ggml import convert_model
ggml_model_path = convert_model(input_path=pretrained_model_name_or_path,
output_path=cache_dir,
model_family=model_family,
dtype=dtype,
tmp_path=tmp_path)
# Hugging Face checkpoint
from bigdl.llm import llm_convert
ggml_model_path = llm_convert(model=pretrained_model_name_or_path,
outfile=cache_dir,
model_family=model_family,
outtype=dtype,
model_format=model_format,
tmp_path=tmp_path)
if model_family == 'llama':
from bigdl.llm.ggml.model.llama import Llama
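
To summarize the new capability, a hedged sketch of loading a GPTQ checkpoint through this transformers-like API (the import path of `AutoModelForCausalLM` is assumed, as it is not shown in this diff; paths are placeholders):

```python
# Import path assumed (not shown in this diff).
from bigdl.llm.ggml.transformers import AutoModelForCausalLM

# The GPTQ checkpoint is first converted via llm_convert into a BigDL-LLM
# optimized GGML binary under cache_dir, then loaded for inference.
llm = AutoModelForCausalLM.from_pretrained(
    "/path/to/vicuna-13B-1.1-GPTQ-4bit-128g/",  # placeholder local GPTQ checkpoint
    model_format="gptq",
    model_family="llama",   # GPTQ input currently supports llama only
    dtype="int4",           # GPTQ input currently supports int4 only
    cache_dir="./",         # where the converted binary is saved
)
```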