gptq2ggml: support loading safetensors model. (#8401)
* update convert gptq to ggml
* update convert gptq to ggml
* gptq to ggml
* update script
* meet code review
* meet code review
parent b9eae23c79
commit e68d631c0a
2 changed files with 34 additions and 21 deletions
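
For orientation, a minimal usage sketch of the updated converter entry point. The keyword names (model_path, output_path, tokenizer_path) match the new signature introduced in the diff below; the import path is an assumption, since the changed filenames are not shown on this page.

    # Hypothetical usage sketch -- the module path is assumed, the signature is taken from this diff.
    from bigdl.llm.gptq.convert import convert_gptq2ggml  # assumed import path

    convert_gptq2ggml(
        model_path="/path/to/llama-7b-gptq",        # directory holding a *.safetensors or *.pt checkpoint
        output_path="/path/to/llama-7b-ggml.bin",   # destination GGML file
        tokenizer_path=None,                        # defaults to <model_path>/tokenizer.model
    )
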
@@ -98,12 +98,13 @@ def llm_convert(model,
                                   outtype.lower())
         outfile = os.path.join(outfile, output_filename)

+        # TODO: delete this when support AutoTokenizer
         if "tokenizer_path" in _used_args:
             gptq_tokenizer_path = _used_args["tokenizer_path"]
         else:
             gptq_tokenizer_path = None

-        convert_gptq2ggml(input_path=model,
+        convert_gptq2ggml(model_path=model,
                           output_path=outfile,
                           tokenizer_path=gptq_tokenizer_path,
                           )

@@ -23,22 +23,15 @@ import os
 import re
 import sys
 import json
+import warnings
 import struct
 import numpy as np
 import torch
 from sentencepiece import SentencePieceProcessor
+from pathlib import Path
 from bigdl.llm.utils.common.log4Error import invalidInputError


-def find_pt_files(directory):
-    pt_files = []
-    for root, _, files in os.walk(directory):
-        for file in files:
-            if file.endswith(".pt"):
-                pt_files.append(os.path.join(root, file))
-    return pt_files
-
-
 def write_header(fout, shape, dst_name, ftype_cur):
     sname = dst_name.encode('utf-8')
     fout.write(struct.pack("iii", len(shape), len(sname), ftype_cur))

@@ -105,6 +98,7 @@ def convert_q4(src_name, dst_name, model, fout, n_head, permute=False):
     qweight = model[f"{src_name}.qweight"].numpy().T  # transpose

     # Q4_1 does not support bias; good thing the bias is always all zeros.
+    # Act-order is not supported.
     invalidInputError(np.all(g_idx[:-1] <= g_idx[1:]),
                       "Act-order is not supported, please use a no act-order model.")
     ftype = 3  # Q4_1

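For readers unfamiliar with the guard above: GPTQ's act-order option (desc_act) quantizes columns in order of activation magnitude, which permutes the per-column group indices, and this converter only handles the plain layout where g_idx is non-decreasing. A toy illustration of the same check (the example arrays are made up):

    # Sketch: the converter accepts monotone group indices and rejects act-order permutations.
    import numpy as np

    g_idx_plain = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])     # groupsize-ordered layout -> accepted
    g_idx_actorder = np.array([2, 0, 1, 0, 2, 1, 0, 2, 1])  # act-order permutation -> rejected

    assert np.all(g_idx_plain[:-1] <= g_idx_plain[1:])
    assert not np.all(g_idx_actorder[:-1] <= g_idx_actorder[1:])
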
@@ -164,13 +158,27 @@ def convert_q4(src_name, dst_name, model, fout, n_head, permute=False):
     blob.tofile(fout)


-def convert_gptq2ggml(input_path, output_path, tokenizer_path=None):
-    input_models = find_pt_files(input_path)
-    invalidInputError(len(input_models) == 1,
-                      "Only support pytorch's .pt format now."
-                      f"There should be only one .pt under {input_path}, "
-                      f"but found {len(input_models)} file(s).")
-    model = torch.load(input_models[0], map_location="cpu")
+def find_quantized_model_file(model_path):
+    model_path = Path(model_path)
+    for ext in ['.safetensors', '.pt']:
+        found = list(model_path.glob(f"*{ext}"))
+        if len(found) > 0:
+            if len(found) != 1:
+                warnings.warn(f'Detected {len(found)} {ext} model, use the first one {found[0]}.')
+            print(f"Detected model file {found[0]}")
+            return str(found[0])
+
+
+def convert_gptq2ggml(model_path, output_path, tokenizer_path=None):
+    input_path = find_quantized_model_file(model_path)
+
+    if input_path.endswith('pt'):
+        model = torch.load(input_path, map_location="cpu")
+    elif input_path.endswith('safetensors'):
+        from safetensors.torch import load_file
+        model = load_file(input_path)
+    else:
+        invalidInputError(False, "unknown input model path, only support .safetensors or .pt file.")

     n_vocab, n_embd = model['model.embed_tokens.weight'].shape
     layer_re = r'model\.layers\.([0-9]+)'

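The reason the rest of the converter is untouched: safetensors.torch.load_file returns the same kind of flat name-to-tensor mapping that torch.load yields for a .pt checkpoint, so lookups such as model['model.embed_tokens.weight'] keep working. A minimal sketch (file names are placeholders):

    # Both loaders produce a dict of parameter name -> torch.Tensor on CPU.
    import torch
    from safetensors.torch import load_file

    state_dict = load_file("model.safetensors")                 # safetensors checkpoint
    # state_dict = torch.load("model.pt", map_location="cpu")   # equivalent for a .pt checkpoint

    print(state_dict["model.embed_tokens.weight"].shape)          # (n_vocab, n_embd)
    print([k for k in state_dict if k.endswith(".qweight")][:3])  # GPTQ-packed weight tensors
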
@@ -182,14 +190,19 @@ def convert_gptq2ggml(input_path, output_path, tokenizer_path=None):
     n_head = {32: 32, 40: 40, 60: 52, 80: 64}[n_layer]

     if not tokenizer_path:
-        tokenizer_path = os.path.join(input_path, "tokenizer.model")
+        tokenizer_path = os.path.join(model_path, "tokenizer.model")
         invalidInputError(os.path.isfile(tokenizer_path),
                           f"tokenizer.model was not found under {tokenizer_path}."
                           f"Please specify the tokenizer-path")

     tokenizer = SentencePieceProcessor(tokenizer_path)
+    vocab_size = tokenizer.vocab_size()
+    # TODO: Support AutoTokenizer
+    # from transformers import AutoTokenizer
+    # tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+    # vocab_size = tokenizer.vocab_size

-    invalidInputError(tokenizer.vocab_size() == n_vocab, "vocab size not match.")
+    invalidInputError(vocab_size <= n_vocab, "vocab size not match.")

     fout = open(output_path, "wb")

@@ -205,7 +218,7 @@ def convert_gptq2ggml(input_path, output_path, tokenizer_path=None):
     fout.write(struct.pack("i" * len(values), *values))

     # This loop unchanged from convert-pth-to-ggml.py:
-    for i in range(tokenizer.vocab_size()):
+    for i in range(vocab_size):
         if tokenizer.is_unknown(i):
             text = " \u2047 ".encode("utf-8")
         elif tokenizer.is_control(i):

@@ -260,5 +273,4 @@ if __name__ == "__main__":
     fname_model = sys.argv[1]
     fname_tokenizer = sys.argv[2]
     out_path = sys.argv[3]
-    invalidInputError(fname_model.endswith(".pt"), "only support pytorch's .pt format now.")
     convert_gptq2ggml(fname_model, out_path, tokenizer_path=fname_tokenizer)
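
With the .pt-only guard removed from the __main__ block, the script accepts either checkpoint format from the command line. The expected invocation, with the script filename left as a placeholder since this page does not show it, is: python <converter-script>.py <gptq-model-dir> <tokenizer.model> <output-ggml-file>.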