LLM: Align converting GPTQ model API with transformer style (#8365)

* LLM: Align GPTQ API with transformer style
This commit is contained in:
Zhao Changmin 2023-06-20 14:27:41 +08:00 committed by GitHub
parent f99d348954
commit 4ec46afa4f
3 changed files with 31 additions and 20 deletions

View file

@ -39,7 +39,7 @@ Here is an example to use `llm-convert` command line tool.
# pth model
llm-convert "/path/to/llama-7b-hf/" --model-format pth --outfile "/path/to/llama-7b-int4/" --model-family "llama"
# gptq model
llm-convert "/path/to/vicuna-13B-1.1-GPTQ-4bit-128g.pt" --model-format gptq --outfile "/path/to/out.bin" --tokenizer-path "/path/to/tokenizer.model" --model-family "llama"
llm-convert "/path/to/vicuna-13B-1.1-GPTQ-4bit-128g/" --model-format gptq --outfile "/path/to/out.bin" --model-family "llama"
```
Here is an example to use `llm_convert` python API.

View file

@ -24,9 +24,7 @@ import argparse
def _special_kwarg_check(kwargs, check_args):
_used_args = {}
for arg in kwargs:
if arg not in check_args:
return False, {arg: kwargs[arg]}
else:
if arg in check_args:
_used_args[arg] = kwargs[arg]
return True, _used_args
@ -38,9 +36,8 @@ def llm_convert(model,
model_format="pth",
**kwargs):
if model_format == "pth":
check, _used_args = _special_kwarg_check(kwargs=kwargs,
_, _used_args = _special_kwarg_check(kwargs=kwargs,
check_args=["tmp_path"])
invalidInputError(check, f"Invaid input kwargs found: {_used_args}")
return ggml_convert_model(input_path=model,
output_path=outfile,
model_family=model_family,
@ -48,21 +45,15 @@ def llm_convert(model,
**_used_args,
)
elif model_format == "gptq":
invalidInputError(model.endswith(".pt"), "only support pytorch's .pt format now.")
invalidInputError(model_family == "llama" and outtype == 'int4',
"Convert GPTQ models should always "
"specify `--model-family llama --dtype int4` in the command line.")
check, _used_args = _special_kwarg_check(kwargs=kwargs,
_, _used_args = _special_kwarg_check(kwargs=kwargs,
check_args=["tokenizer_path"])
invalidInputError(check, f"Invaid input kwargs found: {_used_args}")
invalidInputError("tokenizer_path" in _used_args,
"The GPT-Q model requires the `tokenizer_path` parameter to be provided."
"Usage: convert-model --model-type gptq"
"--model-family llama --input-path llamaXXb-4bit.pt"
"--tokenizer-path tokenizer.model --output-path out.bin")
convert_gptq2ggml(input_path=model,
output_path=outfile,
tokenizer_path=_used_args["tokenizer_path"],
output_path=outfile)
)
return outfile
else:
invalidInputError(False, f"Unsupported input model_type: {model_format}")

View file

@ -30,6 +30,15 @@ from sentencepiece import SentencePieceProcessor
from bigdl.llm.utils.common.log4Error import invalidInputError
def find_pt_files(directory):
pt_files = []
for root, _, files in os.walk(directory):
for file in files:
if file.endswith(".pt"):
pt_files.append(os.path.join(root, file))
return pt_files
def write_header(fout, shape, dst_name, ftype_cur):
sname = dst_name.encode('utf-8')
fout.write(struct.pack("iii", len(shape), len(sname), ftype_cur))
@ -155,8 +164,13 @@ def convert_q4(src_name, dst_name, model, fout, n_head, permute=False):
blob.tofile(fout)
def convert_gptq2ggml(input_path, tokenizer_path, output_path):
model = torch.load(input_path, map_location="cpu")
def convert_gptq2ggml(input_path, output_path, tokenizer_path=None):
input_models = find_pt_files(input_path)
invalidInputError(len(input_models) == 1,
"Only support pytorch's .pt format now."
f"There should be only one .pt under {input_path}, "
f"but found {len(input_models)} file(s).")
model = torch.load(input_models[0], map_location="cpu")
n_vocab, n_embd = model['model.embed_tokens.weight'].shape
layer_re = r'model\.layers\.([0-9]+)'
@ -167,6 +181,12 @@ def convert_gptq2ggml(input_path, tokenizer_path, output_path):
n_mult = 256
n_head = {32: 32, 40: 40, 60: 52, 80: 64}[n_layer]
if not tokenizer_path:
tokenizer_path = os.path.join(input_path, "tokenizer.model")
invalidInputError(os.path.isfile(tokenizer_path),
f"tokenizer.model was not found under {tokenizer_path}."
f"Please specify the tokenizer-path")
tokenizer = SentencePieceProcessor(tokenizer_path)
invalidInputError(tokenizer.vocab_size() == n_vocab, "vocab size not match.")
@ -241,4 +261,4 @@ if __name__ == "__main__":
fname_tokenizer = sys.argv[2]
out_path = sys.argv[3]
invalidInputError(fname_model.endswith(".pt"), "only support pytorch's .pt format now.")
convert_gptq2ggml(fname_model, fname_tokenizer, out_path)
convert_gptq2ggml(fname_model, out_path, tokenizer_path=fname_tokenizer)