gptq2ggml: support loading safetensors model. (#8401)

* update convert gptq to ggml

* update convert gptq to ggml

* gptq to ggml

* update script

* meet code review

* meet code review
Xin Qiu 2023-06-27 11:19:33 +08:00 committed by GitHub
parent b9eae23c79
commit e68d631c0a
2 changed files with 34 additions and 21 deletions

@@ -98,12 +98,13 @@ def llm_convert(model,
outtype.lower())
outfile = os.path.join(outfile, output_filename)
# TODO: delete this when AutoTokenizer is supported
if "tokenizer_path" in _used_args:
gptq_tokenizer_path = _used_args["tokenizer_path"]
else:
gptq_tokenizer_path = None
convert_gptq2ggml(input_path=model,
convert_gptq2ggml(model_path=model,
output_path=outfile,
tokenizer_path=gptq_tokenizer_path,
)

@@ -23,22 +23,15 @@ import os
import re
import sys
import json
import warnings
import struct
import numpy as np
import torch
from sentencepiece import SentencePieceProcessor
from pathlib import Path
from bigdl.llm.utils.common.log4Error import invalidInputError
def find_pt_files(directory):
pt_files = []
for root, _, files in os.walk(directory):
for file in files:
if file.endswith(".pt"):
pt_files.append(os.path.join(root, file))
return pt_files
def write_header(fout, shape, dst_name, ftype_cur):
sname = dst_name.encode('utf-8')
fout.write(struct.pack("iii", len(shape), len(sname), ftype_cur))
@@ -105,6 +98,7 @@ def convert_q4(src_name, dst_name, model, fout, n_head, permute=False):
qweight = model[f"{src_name}.qweight"].numpy().T # transpose
# Q4_1 does not support bias; good thing the bias is always all zeros.
# Act-order is not supported.
invalidInputError(np.all(g_idx[:-1] <= g_idx[1:]),
"Act-order is not supported, please use a no act-order model.")
ftype = 3 # Q4_1
@@ -164,13 +158,27 @@ def convert_q4(src_name, dst_name, model, fout, n_head, permute=False):
blob.tofile(fout)
def convert_gptq2ggml(input_path, output_path, tokenizer_path=None):
input_models = find_pt_files(input_path)
invalidInputError(len(input_models) == 1,
"Only support pytorch's .pt format now."
f"There should be only one .pt under {input_path}, "
f"but found {len(input_models)} file(s).")
model = torch.load(input_models[0], map_location="cpu")
def find_quantized_model_file(model_path):
model_path = Path(model_path)
for ext in ['.safetensors', '.pt']:
found = list(model_path.glob(f"*{ext}"))
if len(found) > 0:
if len(found) != 1:
warnings.warn(f'Detected {len(found)} {ext} models; using the first one: {found[0]}.')
print(f"Detected model file {found[0]}")
return str(found[0])
def convert_gptq2ggml(model_path, output_path, tokenizer_path=None):
input_path = find_quantized_model_file(model_path)
if input_path.endswith('pt'):
model = torch.load(input_path, map_location="cpu")
elif input_path.endswith('safetensors'):
from safetensors.torch import load_file
model = load_file(input_path)
else:
invalidInputError(False, "Unknown model format; only .safetensors and .pt files are supported.")
n_vocab, n_embd = model['model.embed_tokens.weight'].shape
layer_re = r'model\.layers\.([0-9]+)'
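
For readers who want to try the new loading path outside this repo, here is a minimal standalone sketch of the same idea (prefer a *.safetensors file, otherwise fall back to *.pt). The helper name and error handling are illustrative additions, not part of this commit, and it assumes the safetensors package is installed:

```python
from pathlib import Path
import torch

def load_gptq_state_dict(model_dir: str) -> dict:
    """Locate and load a GPTQ checkpoint as a {name: tensor} dict (illustrative helper)."""
    root = Path(model_dir)
    for ext in (".safetensors", ".pt"):
        candidates = sorted(root.glob(f"*{ext}"))
        if candidates:
            path = candidates[0]
            if ext == ".safetensors":
                from safetensors.torch import load_file  # requires the safetensors package
                return load_file(str(path))
            return torch.load(path, map_location="cpu")
    raise FileNotFoundError(f"no .safetensors or .pt file found under {model_dir}")
```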
@@ -182,14 +190,19 @@ def convert_gptq2ggml(input_path, output_path, tokenizer_path=None):
n_head = {32: 32, 40: 40, 60: 52, 80: 64}[n_layer]
if not tokenizer_path:
tokenizer_path = os.path.join(input_path, "tokenizer.model")
tokenizer_path = os.path.join(model_path, "tokenizer.model")
invalidInputError(os.path.isfile(tokenizer_path),
f"tokenizer.model was not found under {tokenizer_path}."
f"Please specify the tokenizer-path")
tokenizer = SentencePieceProcessor(tokenizer_path)
vocab_size = tokenizer.vocab_size()
# TODO: Support AutoTokenizer
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
# vocab_size = tokenizer.vocab_size
invalidInputError(tokenizer.vocab_size() == n_vocab, "vocab size not match.")
invalidInputError(vocab_size <= n_vocab, "tokenizer vocab size is larger than the model's vocab size.")
fout = open(output_path, "wb")
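
The vocab check above is relaxed from == to <= presumably because some GPTQ checkpoints pad the embedding matrix beyond the tokenizer's vocabulary. A small illustrative check of the two sizes (not part of the commit):

```python
from sentencepiece import SentencePieceProcessor

def check_vocab(model: dict, tokenizer_model: str) -> None:
    # n_vocab counts rows of the (possibly padded) embedding matrix in the checkpoint;
    # the SentencePiece tokenizer may report fewer tokens than that.
    n_vocab, _ = model["model.embed_tokens.weight"].shape
    sp = SentencePieceProcessor(tokenizer_model)
    print(f"checkpoint embedding rows: {n_vocab}, tokenizer tokens: {sp.vocab_size()}")
    assert sp.vocab_size() <= n_vocab, "tokenizer vocab exceeds the checkpoint's embedding rows"
```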
@@ -205,7 +218,7 @@ def convert_gptq2ggml(input_path, output_path, tokenizer_path=None):
fout.write(struct.pack("i" * len(values), *values))
# This loop unchanged from convert-pth-to-ggml.py:
for i in range(tokenizer.vocab_size()):
for i in range(vocab_size):
if tokenizer.is_unknown(i):
text = " \u2047 ".encode("utf-8")
elif tokenizer.is_control(i):
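
The loop above is taken from convert-pth-to-ggml.py; for context, a hedged sketch of how such a loop typically maps each SentencePiece id to the bytes written into the GGML vocab (the byte-token and default branches are reconstructed from the upstream script and are not shown in this diff):

```python
from sentencepiece import SentencePieceProcessor

def piece_to_bytes(tokenizer: SentencePieceProcessor, i: int) -> bytes:
    if tokenizer.is_unknown(i):
        return " \u2047 ".encode("utf-8")   # visible placeholder for <unk>, as in the diff above
    if tokenizer.is_control(i):
        return b""                          # control tokens (<s>, </s>) carry no surface text
    if tokenizer.is_byte(i):
        # byte-fallback pieces look like "<0x0A>"; keep just the raw byte
        return bytes([int(tokenizer.id_to_piece(i)[3:-1], 16)])
    # ordinary pieces: replace SentencePiece's "\u2581" word-boundary marker with a space
    return tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
```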
@@ -260,5 +273,4 @@ if __name__ == "__main__":
fname_model = sys.argv[1]
fname_tokenizer = sys.argv[2]
out_path = sys.argv[3]
invalidInputError(fname_model.endswith(".pt"), "only support pytorch's .pt format now.")
convert_gptq2ggml(fname_model, out_path, tokenizer_path=fname_tokenizer)
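
For completeness, a hypothetical usage sketch of convert_gptq2ggml after this commit; the import path and file paths below are placeholders, not taken from the diff:

```python
# Placeholder import: substitute the module that actually defines convert_gptq2ggml.
from convert_gptq_to_ggml import convert_gptq2ggml  # hypothetical module name

convert_gptq2ggml(
    model_path="./llama-7b-4bit-gptq",                      # directory containing *.safetensors or *.pt
    output_path="./llama-7b-4bit-gptq-ggml.bin",            # GGML file to write
    tokenizer_path="./llama-7b-4bit-gptq/tokenizer.model",  # optional; defaults to <model_path>/tokenizer.model
)
```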