LLM: convert and quantize support for StarCoder (#8359)
* basic support for starcoder
* update from_pretrained
* fix bug and fix style
parent 5f4f399ca7
commit f99d348954
5 changed files with 206 additions and 19 deletions
In the ggml conversion helpers, a `_convert_starcoder` wrapper is added and wired into the `_convert_to_ggml` dispatch:

```diff
@@ -72,6 +72,10 @@ def _convert_bloom(model_path, outfile_dir, outtype):
     _convert_bloom_hf_to_ggml(model_path, outfile_dir, outtype)
 
 
+def _convert_starcoder(model_path, outfile_dir, outtype):
+    _convert_starcoder_hf_to_ggml(model_path, outfile_dir, outtype)
+
+
 def _convert_to_ggml(model_path: str, outfile_dir: str,
                      model_family: str = 'llama', outtype: str="fp16"):
     """
@@ -84,12 +88,12 @@ def _convert_to_ggml(model_path: str, outfile_dir: str,
            For lora finetuned model, the path should be pointed to a merged weight.
     :param outfile_dir: str, the directory to save ggml compatible file, for example `./models`.
     :param model_family: Which model family your input model belongs to. Default to `llama`.
-           Now only `llama`/`bloom`/`gptneox` are supported.
+           Now only `llama`/`bloom`/`gptneox`/`starcoder` are supported.
     :param outtype: specify the output format. Defalut to `fp16`. Now `fp32`/`fp16` are supported.
     """
-    invalidInputError(model_family in ['llama', 'bloom', 'gptneox'],
+    invalidInputError(model_family in ['llama', 'bloom', 'gptneox', 'starcoder'],
                       "Now we only support quantization of model \
-                       family('llama', 'bloom', 'gptneox')",
+                       family('llama', 'bloom', 'gptneox', 'starcoder')",
                       "{} is not in the list.".format(model_family))
     invalidInputError(os.path.exists(model_path),
                       "The file {} was not found".format(model_path))
@@ -108,3 +112,5 @@ def _convert_to_ggml(model_path: str, outfile_dir: str,
         _convert_gptneox(model_path, outfile_dir, outtype)
     if model_family == 'bloom':
         _convert_bloom(model_path, outfile_dir, outtype)
+    if model_family == 'starcoder':
+        _convert_starcoder(model_path, outfile_dir, outtype)
```
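With the dispatch above in place, converting a StarCoder HuggingFace checkpoint follows the same pattern as the other families. A minimal sketch (the import path and the checkpoint/output paths are assumptions for illustration):

```python
# Minimal sketch: convert a HuggingFace StarCoder checkpoint to a ggml fp16 file.
# The module path and the two paths below are placeholders, not part of the diff.
from bigdl.llm.ggml.convert import _convert_to_ggml  # assumed import path

_convert_to_ggml(model_path="/path/to/starcoder-hf-checkpoint",
                 outfile_dir="./models",
                 model_family="starcoder",
                 outtype="fp16")
```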
`convert_model` and its CLI entry point now accept `starcoder` as a model family, including for `int8` quantization:

```diff
@@ -39,7 +39,7 @@ def convert_model(input_path: str,
     :param output_path: Save path of output quantized model. You must pass a *directory* to
            save all related output.
     :param model_family: Which model family your input model belongs to.
-           Now only `llama`/`bloom`/`gptneox` are supported.
+           Now only `llama`/`bloom`/`gptneox`/`starcoder` are supported.
     :param dtype: Which quantized precision will be converted.
            Now only `int4` and `int8` are supported, and `int8` only works for `llama`
            and `gptneox`.
@@ -53,9 +53,9 @@ def convert_model(input_path: str,
     # make sure directory exists
     os.makedirs(output_path, exist_ok=True)
     # check input value
-    invalidInputError(model_family in ['llama', 'bloom', 'gptneox'],
+    invalidInputError(model_family in ['llama', 'bloom', 'gptneox', 'starcoder'],
                       "Now we only support quantization of model \
-                       family('llama', 'bloom', 'gptneox')",
+                       family('llama', 'bloom', 'gptneox', 'starcoder')",
                       "{} is not in the list.".format(model_family))
     invalidInputError(os.path.isdir(output_path),
                       "The output_path {} was not a directory".format(output_path))
@@ -72,9 +72,9 @@ def convert_model(input_path: str,
         dtype = 'q4_0'
     elif dtype == 'int8':
         dtype = 'q8_0'
-        invalidInputError(model_family in ['llama', 'gptneox'],
+        invalidInputError(model_family in ['llama', 'gptneox', 'starcoder'],
                           "Now we only support int8 quantization of model \
-                           family('llama', 'gptneox')",
+                           family('llama', 'gptneox', 'starcoder')",
                           "{} is not in the list.".format(model_family))
 
     if tmp_path is not None:
@@ -110,7 +110,7 @@ def main():
                         help=("output_path,save path of output quantized model."))
     parser.add_argument('-x', '--model_family', type=str, required=True,
                         help=("model_family: Which model family your input model belongs to."
-                              "Now only `llama`/`bloom`/`gptneox` are supported."))
+                              "Now only `llama`/`bloom`/`gptneox`/`starcoder` are supported."))
     parser.add_argument('-t', '--dtype', type=str, default="int4",
                         help="Which quantized precision will be converted.")
     parser.add_argument('-p', '--tmp_path', type=str, default=None,
```
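For illustration, a hedged sketch of the one-call convert-and-quantize path after this change; the import location and all paths are assumptions, and `'int8'` is shown only because the check above now allows it for `starcoder`:

```python
# Hedged sketch: convert + quantize a StarCoder checkpoint via convert_model.
# Import path and file paths are assumptions, not taken from the diff.
from bigdl.llm.ggml import convert_model  # assumed import path

convert_model(input_path="/path/to/starcoder-hf-checkpoint",
              output_path="./quantized-models",   # must be a directory
              model_family="starcoder",
              dtype="int4")                        # 'int8' is also accepted for starcoder now
```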
`quantize` registers a StarCoder quantization-type table and accepts the new family:

```diff
@@ -36,10 +36,16 @@ _gptneox_quantize_type = {"q4_0": 2,
                           "q5_0": 8,
                           "q5_1": 9,
                           "q8_0": 7}
+_starcoder_quantize_type = {"q4_0": 2,
+                            "q4_1": 3,
+                            "q5_0": 8,
+                            "q5_1": 9,
+                            "q8_0": 7}
 
 _quantize_type = {"llama": _llama_quantize_type,
                   "bloom": _bloom_quantize_type,
-                  "gptneox": _gptneox_quantize_type}
+                  "gptneox": _gptneox_quantize_type,
+                  "starcoder": _starcoder_quantize_type}
 
 
 def quantize(input_path: str, output_path: str,
@@ -52,19 +58,20 @@ def quantize(input_path: str, output_path: str,
            save all related output. Filename of quantized model will be like
            `bigdl_llm_llama_q4_0.bin`.
     :param model_family: Which model family your input model belongs to.
-           Now only `llama`/`bloom`/`gptneox` are supported.
+           Now only `llama`/`bloom`/`gptneox`/`starcoder` are supported.
     :param dtype: Quantization method which differs in the resulting model disk size and
            inference speed. Defalut to `q4_0`. Difference model family may support
            different types, now the supported list is:
            llama : "q4_0", "q4_1", "q4_2"
            bloom : "q4_0", "q4_1"
            gptneox : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
+           starcoder : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
 
     :return: the path str to the converted ggml binary checkpoint
     """
-    invalidInputError(model_family in ['llama', 'bloom', 'gptneox'],
+    invalidInputError(model_family in ['llama', 'bloom', 'gptneox', 'starcoder'],
                       "Now we only support quantization of model \
-                       family('llama', 'bloom', 'gptneox')",
+                       family('llama', 'bloom', 'gptneox', 'starcoder')",
                       "{} is not in the list.".format(model_family))
     invalidInputError(os.path.isfile(input_path),
                       "The file {} was not found".format(input_path))
```
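A hedged usage sketch for quantizing an already-converted StarCoder ggml file; the import path and file names are assumptions, while the parameter names and supported `dtype` values come from the docstring above:

```python
# Hedged sketch: quantize a ggml-format StarCoder checkpoint to q4_0.
# Import path and file paths below are assumptions for illustration.
from bigdl.llm.ggml.quantize import quantize  # assumed import path

quantized_path = quantize(input_path="./models/ggml-starcoder-fp16.bin",
                          output_path="./models",          # directory for the output
                          model_family="starcoder",
                          dtype="q4_0")
print(quantized_path)  # per the docstring, something like ./models/bigdl_llm_starcoder_q4_0.bin
```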
`AutoModelForCausalLM.from_pretrained` accepts `starcoder` and dispatches to the new `Starcoder` model class:

```diff
@@ -50,10 +50,10 @@ class AutoModelForCausalLM:
                3. a str for huggingface hub repo id.
 
         :param model_family: the model family of the pretrained checkpoint.
-               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``.
+               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"`` and ``"starcoder"``.
         :param dtype: Which quantized precision will be converted.
                Now only `int4` and `int8` are supported, and `int8` only works for `llama`
-               and `gptneox`.
+               , `gptneox` and `starcoder`.
         :param cache_dir: (optional) this parameter will only be used when
                ``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id.
                It indicates the saving path for the converted low precision model.
@@ -63,9 +63,9 @@ class AutoModelForCausalLM:
 
         :return: a model instance
         """
-        invalidInputError(model_family in ['llama', 'gptneox', 'bloom'],
+        invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder'],
                           "Now we only support model family: 'llama', 'gptneox', 'bloom',"
-                          "'{}' is not in the list.".format(model_family))
+                          " 'starcoder', '{}' is not in the list.".format(model_family))
         invalidInputError(dtype.lower() in ['int4', 'int8'],
                           "Now we only support int4 and int8 as date type for weight")
 
@@ -110,3 +110,6 @@ class AutoModelForCausalLM:
         elif model_family == 'bloom':
             from bigdl.llm.ggml.model.bloom import Bloom
             return Bloom(model_path=ggml_model_path, **kwargs)
+        elif model_family == 'starcoder':
+            from bigdl.llm.ggml.model.starcoder import Starcoder
+            return Starcoder(model_path=ggml_model_path, **kwargs)
```
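Putting it together, the high-level entry point can now load a StarCoder checkpoint end to end. A hedged sketch, where the import path and the hub repo id are assumptions and only the parameter names come from the docstring above:

```python
# Hedged sketch: load a StarCoder checkpoint through the updated from_pretrained.
# The import path and the repo id are assumptions for illustration only.
from bigdl.llm.ggml.transformers import AutoModelForCausalLM  # assumed import path

llm = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="bigcode/starcoderbase-1b",  # hypothetical hub repo id
    model_family="starcoder",
    dtype="int4",
    cache_dir="./converted-models",
)
```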
Finally, the HF-to-ggml conversion utilities export and implement `_convert_starcoder_hf_to_ggml`:

```diff
@@ -87,7 +87,8 @@ __all__ = ['Params',
            'load_vocab',
            'default_outfile',
            '_convert_gptneox_hf_to_ggml',
-           '_convert_bloom_hf_to_ggml']
+           '_convert_bloom_hf_to_ggml',
+           '_convert_starcoder_hf_to_ggml']
 
 
 @dataclass(frozen=True)
@@ -1415,3 +1416,173 @@ def _convert_bloom_hf_to_ggml(model_path, outfile_dir, outtype):
         data.tofile(fout)
 
     fout.close()
+
+
+def _convert_starcoder_hf_to_ggml(model_path, outfile_dir, outtype):
+    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+    import torch
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    hparams = config.to_dict()
+    model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
+                                                 torch_dtype=torch.float16
+                                                 if outtype == "f16" else torch.float32,
+                                                 low_cpu_mem_usage=True,
+                                                 trust_remote_code=True,
+                                                 offload_state_dict=True)
+
+    list_vars = model.state_dict()
+
+    encoder = tokenizer.vocab
+    # Add added_tokens (special tokens) to the encoder
+    encoder.update(tokenizer.get_added_vocab())
+
+    filestem = Path(model_path).stem
+    fn_out = os.path.join(outfile_dir, f"ggml-{filestem}-{outtype}.bin")
+    fout = open(fn_out, "wb")
+
+    if outtype == "f16":
+        ftype = 1
+    else:
+        ftype = 0
+
+    fout.write(struct.pack("i", 0x67676d6c))  # magic: ggml in hex
+    vocab_size = hparams["vocab_size"]
+    fout.write(struct.pack("i", vocab_size))
+    # fout.write(struct.pack("i", len(encoder)))
+    fout.write(struct.pack("i", hparams["n_positions"]))
+    fout.write(struct.pack("i", hparams["n_embd"]))
+    fout.write(struct.pack("i", hparams["n_head"]))
+    fout.write(struct.pack("i", hparams["n_layer"]))
+    fout.write(struct.pack("i", ftype))
+
+    byte_encoder = bytes_to_unicode()
+    byte_decoder = {v: k for k, v in byte_encoder.items()}
+
+    fout.write(struct.pack("i", vocab_size))
+
+    counter = 0
+    # sort by value
+    for key in sorted(encoder, key=encoder.get):
+        text = bytearray([byte_decoder[c] for c in key])
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+        counter += 1
+
+    # TODO: Repeat last token until vocab_size
+    while counter < vocab_size:
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+        counter += 1
+
+    for name in list_vars.keys():
+        data = list_vars[name].squeeze().numpy()
+        print("Processing variable: " + name + " with shape: ", data.shape)
+
+        # rename headers to keep compatibility
+        if name == "transformer.ln_f.weight":
+            name = "model/ln_f/g"
+        elif name == "transformer.ln_f.bias":
+            name = "model/ln_f/b"
+        elif name == "transformer.wte.weight":
+            name = "model/wte"
+        elif name == "transformer.wpe.weight":
+            name = "model/wpe"
+        elif name == "lm_head.weight":
+            name = "model/lm_head"
+        elif re.match(r"transformer.h\.\d+\.ln_1\.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/ln_1/g"
+        elif re.match(r"transformer.h\.\d+\.ln_1\.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/ln_1/b"
+        elif re.match(r"transformer.h\.\d+\.attn\.c_attn\.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/attn/c_attn/w"
+        elif re.match(r"transformer.h\.\d+\.attn\.c_attn\.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/attn/c_attn/b"
+        elif re.match(r"transformer.h\.\d+\.attn\.c_proj\.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/attn/c_proj/w"
+        elif re.match(r"transformer.h.\d+.attn.c_proj.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/attn/c_proj/b"
+        elif re.match(r"transformer.h.\d+.ln_2.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/ln_2/g"
+        elif re.match(r"transformer.h.\d+.ln_2.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/ln_2/b"
+        elif re.match(r"transformer.h.\d+.mlp.c_fc.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/mlp/c_fc/w"
+        elif re.match(r"transformer.h.\d+.mlp.c_fc.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/mlp/c_fc/b"
+        elif re.match(r"transformer.h.\d+.mlp.c_proj.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/mlp/c_proj/w"
+        elif re.match(r"transformer.h.\d+.mlp.c_proj.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/mlp/c_proj/b"
+        else:
+            print("Unrecognized variable name. %s", name)
+
+        # we don't need these
+        if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"):
+            print("  Skipping variable: " + name)
+            continue
+
+        n_dims = len(data.shape)
+
+        ftype_cur = 0
+        if ftype == 1:
+            if (name == "model/wte" or name == "model/lm_head" or name[-2:] == "/g" or
+                    name[-2:] == "/w") and n_dims == 2:
+                print("  Converting to float16")
+                data = data.astype(np.float16)
+                ftype_cur = 1
+            else:
+                print("  Converting to float32")
+                data = data.astype(np.float32)
+                ftype_cur = 0
+
+        "model/h.*/attn/c_attn/w"
+        "model/h.*/attn/c_proj/w"
+        "model/h.*/mlp/c_fc/w"
+        "model/h.*/mlp/c_proj/w"
+        if name[-14:] == "/attn/c_attn/w" or name[-14:] == "/attn/c_attn/b":
+            print("  Duplicate K,V heads to use MHA instead of MQA")
+
+            embed_dim = hparams["n_embd"]
+            head_dim = embed_dim // hparams["n_head"]
+
+            # ((n_heads + 2) * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
+            q, k, v = np.split(data,
+                               (hparams["n_head"] * head_dim,
+                                (hparams["n_head"] + 1) * head_dim),
+                               axis=0)
+            # duplicate k, v along the first axis (head_dim, hidden_dim) ->
+            # (n_heads * head_dim, hidden_dim)
+            if len(k.shape) == 2:
+                k = np.tile(k, (hparams["n_head"], 1))
+                v = np.tile(v, (hparams["n_head"], 1))
+            elif len(k.shape) == 1:
+                k = np.tile(k, (hparams["n_head"]))
+                v = np.tile(v, (hparams["n_head"]))
+            # concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) ->
+            # (3 * n_heads * head_dim, hidden_dim)
+            data = np.concatenate((q, k, v), axis=0)
+
+        # header
+        str = name.encode('utf-8')
+        fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+        for i in range(n_dims):
+            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+        fout.write(str)
+
+        # data
+        data.tofile(fout)
+
+    fout.close()
```
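The only non-mechanical step in the new converter is widening StarCoder's multi-query attention (a single shared K/V head) into multi-head layout before the `c_attn` tensor is written. A standalone sketch of that reshaping on toy data; the sizes below are made up purely for illustration, the real values come from `hparams`:

```python
# Toy illustration of the MQA -> MHA widening applied to c_attn weights above.
import numpy as np

n_head, head_dim, hidden_dim = 4, 3, 12                        # illustrative sizes only
data = np.random.rand((n_head + 2) * head_dim, hidden_dim)     # packed [Q | single K | single V]

# Split into Q (n_head * head_dim rows), K (head_dim rows), V (head_dim rows).
q, k, v = np.split(data, (n_head * head_dim, (n_head + 1) * head_dim), axis=0)

# Duplicate the single K/V head across all heads, then repack as Q, K, V.
k = np.tile(k, (n_head, 1))
v = np.tile(v, (n_head, 1))
mha = np.concatenate((q, k, v), axis=0)
assert mha.shape == (3 * n_head * head_dim, hidden_dim)
```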