LLM: convert and quantize support for StarCoder (#8359)
* basic support for starcoder
* update from_pretrained
* fix bug and fix style
Parent: 5f4f399ca7
Commit: f99d348954
5 changed files with 206 additions and 19 deletions
@@ -72,6 +72,10 @@ def _convert_bloom(model_path, outfile_dir, outtype):
     _convert_bloom_hf_to_ggml(model_path, outfile_dir, outtype)
 
 
+def _convert_starcoder(model_path, outfile_dir, outtype):
+    _convert_starcoder_hf_to_ggml(model_path, outfile_dir, outtype)
+
+
 def _convert_to_ggml(model_path: str, outfile_dir: str,
                      model_family: str = 'llama', outtype: str="fp16"):
     """
@@ -84,12 +88,12 @@ def _convert_to_ggml(model_path: str, outfile_dir: str,
            For lora finetuned model, the path should be pointed to a merged weight.
     :param outfile_dir: str, the directory to save ggml compatible file, for example `./models`.
     :param model_family: Which model family your input model belongs to. Default to `llama`.
-           Now only `llama`/`bloom`/`gptneox` are supported.
+           Now only `llama`/`bloom`/`gptneox`/`starcoder` are supported.
     :param outtype: specify the output format. Defalut to `fp16`. Now `fp32`/`fp16` are supported.
     """
-    invalidInputError(model_family in ['llama', 'bloom', 'gptneox'],
+    invalidInputError(model_family in ['llama', 'bloom', 'gptneox', 'starcoder'],
                       "Now we only support quantization of model \
-                      family('llama', 'bloom', 'gptneox')",
+                      family('llama', 'bloom', 'gptneox', 'starcoder')",
                       "{} is not in the list.".format(model_family))
     invalidInputError(os.path.exists(model_path),
                       "The file {} was not found".format(model_path))
@@ -108,3 +112,5 @@ def _convert_to_ggml(model_path: str, outfile_dir: str,
         _convert_gptneox(model_path, outfile_dir, outtype)
     if model_family == 'bloom':
         _convert_bloom(model_path, outfile_dir, outtype)
+    if model_family == 'starcoder':
+        _convert_starcoder(model_path, outfile_dir, outtype)
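With the dispatch branch above in place, converting a StarCoder checkpoint works the same way as the other supported families. A minimal usage sketch follows; the import path and the local directories are assumptions for illustration, not taken from this diff.

    # Hedged sketch: convert a Hugging Face StarCoder checkpoint to a ggml fp16 file.
    # The module path is an assumption; adjust it to wherever _convert_to_ggml
    # is defined in your installed bigdl-llm package.
    from bigdl.llm.ggml.convert import _convert_to_ggml  # assumed import path

    _convert_to_ggml(model_path="./starcoder-hf",   # hypothetical HF checkpoint directory
                     outfile_dir="./models",        # directory for the ggml output file
                     model_family="starcoder",
                     outtype="fp16")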
@@ -39,7 +39,7 @@ def convert_model(input_path: str,
     :param output_path: Save path of output quantized model. You must pass a *directory* to
            save all related output.
     :param model_family: Which model family your input model belongs to.
-           Now only `llama`/`bloom`/`gptneox` are supported.
+           Now only `llama`/`bloom`/`gptneox`/`starcoder` are supported.
     :param dtype: Which quantized precision will be converted.
            Now only `int4` and `int8` are supported, and `int8` only works for `llama`
            and `gptneox`.
@@ -53,9 +53,9 @@ def convert_model(input_path: str,
     # make sure directory exists
     os.makedirs(output_path, exist_ok=True)
     # check input value
-    invalidInputError(model_family in ['llama', 'bloom', 'gptneox'],
+    invalidInputError(model_family in ['llama', 'bloom', 'gptneox', 'starcoder'],
                       "Now we only support quantization of model \
-                      family('llama', 'bloom', 'gptneox')",
+                      family('llama', 'bloom', 'gptneox', 'starcoder')",
                       "{} is not in the list.".format(model_family))
     invalidInputError(os.path.isdir(output_path),
                       "The output_path {} was not a directory".format(output_path))
@@ -72,9 +72,9 @@ def convert_model(input_path: str,
         dtype = 'q4_0'
     elif dtype == 'int8':
         dtype = 'q8_0'
-        invalidInputError(model_family in ['llama', 'gptneox'],
+        invalidInputError(model_family in ['llama', 'gptneox', 'starcoder'],
                           "Now we only support int8 quantization of model \
-                          family('llama', 'gptneox')",
+                          family('llama', 'gptneox', 'starcoder')",
                           "{} is not in the list.".format(model_family))
 
     if tmp_path is not None:
@@ -110,7 +110,7 @@ def main():
                         help=("output_path,save path of output quantized model."))
     parser.add_argument('-x', '--model_family', type=str, required=True,
                         help=("model_family: Which model family your input model belongs to."
-                              "Now only `llama`/`bloom`/`gptneox` are supported."))
+                              "Now only `llama`/`bloom`/`gptneox`/`starcoder` are supported."))
     parser.add_argument('-t', '--dtype', type=str, default="int4",
                         help="Which quantized precision will be converted.")
     parser.add_argument('-p', '--tmp_path', type=str, default=None,
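Taken together with the updated docstring, validation, and argparse help, `convert_model` now accepts `starcoder` end to end. A minimal Python sketch follows; the parameter names come from the diff, while the import path and the local paths are assumptions for illustration.

    # Hedged sketch: one-step convert + quantize of a StarCoder checkpoint.
    # Import path and local directories are assumptions, not taken from this diff.
    from bigdl.llm.ggml import convert_model  # assumed import path

    convert_model(input_path="./starcoder-hf",     # hypothetical HF checkpoint directory
                  output_path="./starcoder-ggml",  # a directory; created if missing
                  model_family="starcoder",
                  dtype="int4")                    # 'int8' is now also accepted for starcoder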
@@ -36,10 +36,16 @@ _gptneox_quantize_type = {"q4_0": 2,
                           "q5_0": 8,
                           "q5_1": 9,
                           "q8_0": 7}
+_starcoder_quantize_type = {"q4_0": 2,
+                            "q4_1": 3,
+                            "q5_0": 8,
+                            "q5_1": 9,
+                            "q8_0": 7}
 
 _quantize_type = {"llama": _llama_quantize_type,
                   "bloom": _bloom_quantize_type,
-                  "gptneox": _gptneox_quantize_type}
+                  "gptneox": _gptneox_quantize_type,
+                  "starcoder": _starcoder_quantize_type}
 
 
 def quantize(input_path: str, output_path: str,
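The new `_starcoder_quantize_type` table mirrors the gptneox one, so StarCoder gets the same five quantization formats. Each entry presumably maps a dtype string to the integer type id handed to the native ggml quantizer. A tiny self-contained illustration of the lookup, using a local copy of the table from this commit:

    # Local copy of the mapping added in this commit, used here only to
    # illustrate the dtype-string -> integer type-id lookup.
    _starcoder_quantize_type = {"q4_0": 2,
                                "q4_1": 3,
                                "q5_0": 8,
                                "q5_1": 9,
                                "q8_0": 7}

    dtype = "q5_0"
    itype = _starcoder_quantize_type[dtype]
    print(dtype, "->", itype)   # q5_0 -> 8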
@@ -52,19 +58,20 @@ def quantize(input_path: str, output_path: str,
            save all related output. Filename of quantized model will be like
            `bigdl_llm_llama_q4_0.bin`.
     :param model_family: Which model family your input model belongs to.
-           Now only `llama`/`bloom`/`gptneox` are supported.
+           Now only `llama`/`bloom`/`gptneox`/`starcoder` are supported.
     :param dtype: Quantization method which differs in the resulting model disk size and
            inference speed. Defalut to `q4_0`. Difference model family may support
            different types, now the supported list is:
            llama : "q4_0", "q4_1", "q4_2"
            bloom : "q4_0", "q4_1"
            gptneox : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
+           starcoder : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
 
     :return: the path str to the converted ggml binary checkpoint
     """
-    invalidInputError(model_family in ['llama', 'bloom', 'gptneox'],
+    invalidInputError(model_family in ['llama', 'bloom', 'gptneox', 'starcoder'],
                       "Now we only support quantization of model \
-                      family('llama', 'bloom', 'gptneox')",
+                      family('llama', 'bloom', 'gptneox', 'starcoder')",
                       "{} is not in the list.".format(model_family))
     invalidInputError(os.path.isfile(input_path),
                       "The file {} was not found".format(input_path))
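With the table and the relaxed check above, quantizing a converted StarCoder ggml file looks the same as for the other families. A minimal sketch; the import path and file names are assumptions, while the supported dtype list for starcoder comes from the docstring above.

    # Hedged sketch: quantize an fp16 ggml StarCoder checkpoint down to q4_0.
    # Import path and file locations are assumptions, not taken from this diff.
    from bigdl.llm.ggml import quantize  # assumed import path

    out_ckpt = quantize(input_path="./models/ggml-starcoder-fp16.bin",  # hypothetical converted file
                        output_path="./models",                         # directory for the result
                        model_family="starcoder",
                        dtype="q4_0")  # q4_1 / q5_0 / q5_1 / q8_0 are also listed for starcoder
    print(out_ckpt)  # path to the quantized binary, per the docstring above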
@@ -50,10 +50,10 @@ class AutoModelForCausalLM:
               3. a str for huggingface hub repo id.
 
         :param model_family: the model family of the pretrained checkpoint.
-               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``.
+               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"`` and ``"starcoder"``.
         :param dtype: Which quantized precision will be converted.
                Now only `int4` and `int8` are supported, and `int8` only works for `llama`
-               and `gptneox`.
+               , `gptneox` and `starcoder`.
         :param cache_dir: (optional) this parameter will only be used when
               ``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id.
               It indicates the saving path for the converted low precision model.
|
|
@ -63,9 +63,9 @@ class AutoModelForCausalLM:
|
||||||
|
|
||||||
:return: a model instance
|
:return: a model instance
|
||||||
"""
|
"""
|
||||||
invalidInputError(model_family in ['llama', 'gptneox', 'bloom'],
|
invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder'],
|
||||||
"Now we only support model family: 'llama', 'gptneox', 'bloom',"
|
"Now we only support model family: 'llama', 'gptneox', 'bloom',"
|
||||||
"'{}' is not in the list.".format(model_family))
|
" 'starcoder', '{}' is not in the list.".format(model_family))
|
||||||
invalidInputError(dtype.lower() in ['int4', 'int8'],
|
invalidInputError(dtype.lower() in ['int4', 'int8'],
|
||||||
"Now we only support int4 and int8 as date type for weight")
|
"Now we only support int4 and int8 as date type for weight")
|
||||||
|
|
||||||
|
|
@@ -110,3 +110,6 @@ class AutoModelForCausalLM:
         elif model_family == 'bloom':
             from bigdl.llm.ggml.model.bloom import Bloom
             return Bloom(model_path=ggml_model_path, **kwargs)
+        elif model_family == 'starcoder':
+            from bigdl.llm.ggml.model.starcoder import Starcoder
+            return Starcoder(model_path=ggml_model_path, **kwargs)
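The new branch above hands a converted checkpoint to the ggml `Starcoder` class, so the high-level loader now covers StarCoder as well. A minimal sketch follows; the parameters match the docstring in this diff, while the import path of `AutoModelForCausalLM` and the repo id are assumptions for illustration.

    # Hedged sketch: load a StarCoder checkpoint through the bigdl-llm AutoModelForCausalLM.
    # Import path and repo id are assumptions, not taken from this diff.
    from bigdl.llm.ggml.transformers import AutoModelForCausalLM  # assumed import path

    model = AutoModelForCausalLM.from_pretrained(
        "bigcode/starcoderbase-1b",   # hypothetical huggingface hub repo id
        model_family="starcoder",
        dtype="int4",                 # int8 is also accepted for starcoder now
        cache_dir="./converted")      # where the converted low-precision model is saved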
@@ -87,7 +87,8 @@ __all__ = ['Params',
            'load_vocab',
            'default_outfile',
            '_convert_gptneox_hf_to_ggml',
-           '_convert_bloom_hf_to_ggml']
+           '_convert_bloom_hf_to_ggml',
+           '_convert_starcoder_hf_to_ggml']
 
 
 @dataclass(frozen=True)
@@ -1415,3 +1416,173 @@ def _convert_bloom_hf_to_ggml(model_path, outfile_dir, outtype):
         data.tofile(fout)
 
     fout.close()
+
+
+def _convert_starcoder_hf_to_ggml(model_path, outfile_dir, outtype):
+    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+    import torch
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    hparams = config.to_dict()
+    model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
+                                                 torch_dtype=torch.float16
+                                                 if outtype == "f16" else torch.float32,
+                                                 low_cpu_mem_usage=True,
+                                                 trust_remote_code=True,
+                                                 offload_state_dict=True)
+
+    list_vars = model.state_dict()
+
+    encoder = tokenizer.vocab
+    # Add added_tokens (special tokens) to the encoder
+    encoder.update(tokenizer.get_added_vocab())
+
+    filestem = Path(model_path).stem
+    fn_out = os.path.join(outfile_dir, f"ggml-{filestem}-{outtype}.bin")
+    fout = open(fn_out, "wb")
+
+    if outtype == "f16":
+        ftype = 1
+    else:
+        ftype = 0
+
+    fout.write(struct.pack("i", 0x67676d6c))  # magic: ggml in hex
+    vocab_size = hparams["vocab_size"]
+    fout.write(struct.pack("i", vocab_size))
+    # fout.write(struct.pack("i", len(encoder)))
+    fout.write(struct.pack("i", hparams["n_positions"]))
+    fout.write(struct.pack("i", hparams["n_embd"]))
+    fout.write(struct.pack("i", hparams["n_head"]))
+    fout.write(struct.pack("i", hparams["n_layer"]))
+    fout.write(struct.pack("i", ftype))
+
+    byte_encoder = bytes_to_unicode()
+    byte_decoder = {v: k for k, v in byte_encoder.items()}
+
+    fout.write(struct.pack("i", vocab_size))
+
+    counter = 0
+    # sort by value
+    for key in sorted(encoder, key=encoder.get):
+        text = bytearray([byte_decoder[c] for c in key])
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+        counter += 1
+
+    # TODO: Repeat last token until vocab_size
+    while counter < vocab_size:
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+        counter += 1
+
+    for name in list_vars.keys():
+        data = list_vars[name].squeeze().numpy()
+        print("Processing variable: " + name + " with shape: ", data.shape)
+
+        # rename headers to keep compatibility
+        if name == "transformer.ln_f.weight":
+            name = "model/ln_f/g"
+        elif name == "transformer.ln_f.bias":
+            name = "model/ln_f/b"
+        elif name == "transformer.wte.weight":
+            name = "model/wte"
+        elif name == "transformer.wpe.weight":
+            name = "model/wpe"
+        elif name == "lm_head.weight":
+            name = "model/lm_head"
+        elif re.match(r"transformer.h\.\d+\.ln_1\.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/ln_1/g"
+        elif re.match(r"transformer.h\.\d+\.ln_1\.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/ln_1/b"
+        elif re.match(r"transformer.h\.\d+\.attn\.c_attn\.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/attn/c_attn/w"
+        elif re.match(r"transformer.h\.\d+\.attn\.c_attn\.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/attn/c_attn/b"
+        elif re.match(r"transformer.h\.\d+\.attn\.c_proj\.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/attn/c_proj/w"
+        elif re.match(r"transformer.h.\d+.attn.c_proj.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/attn/c_proj/b"
+        elif re.match(r"transformer.h.\d+.ln_2.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/ln_2/g"
+        elif re.match(r"transformer.h.\d+.ln_2.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/ln_2/b"
+        elif re.match(r"transformer.h.\d+.mlp.c_fc.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/mlp/c_fc/w"
+        elif re.match(r"transformer.h.\d+.mlp.c_fc.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/mlp/c_fc/b"
+        elif re.match(r"transformer.h.\d+.mlp.c_proj.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/mlp/c_proj/w"
+        elif re.match(r"transformer.h.\d+.mlp.c_proj.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/mlp/c_proj/b"
+        else:
+            print("Unrecognized variable name. %s", name)
+
+        # we don't need these
+        if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"):
+            print(" Skipping variable: " + name)
+            continue
+
+        n_dims = len(data.shape)
+
+        ftype_cur = 0
+        if ftype == 1:
+            if (name == "model/wte" or name == "model/lm_head" or name[-2:] == "/g" or
+                    name[-2:] == "/w") and n_dims == 2:
+                print(" Converting to float16")
+                data = data.astype(np.float16)
+                ftype_cur = 1
+            else:
+                print(" Converting to float32")
+                data = data.astype(np.float32)
+                ftype_cur = 0
+
+        "model/h.*/attn/c_attn/w"
+        "model/h.*/attn/c_proj/w"
+        "model/h.*/mlp/c_fc/w"
+        "model/h.*/mlp/c_proj/w"
+        if name[-14:] == "/attn/c_attn/w" or name[-14:] == "/attn/c_attn/b":
+            print(" Duplicate K,V heads to use MHA instead of MQA")
+
+            embed_dim = hparams["n_embd"]
+            head_dim = embed_dim // hparams["n_head"]
+
+            # ((n_heads + 2) * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
+            q, k, v = np.split(data,
+                               (hparams["n_head"] * head_dim,
+                                (hparams["n_head"] + 1) * head_dim),
+                               axis=0)
+            # duplicate k, v along the first axis (head_dim, hidden_dim) ->
+            # (n_heads * head_dim, hidden_dim)
+            if len(k.shape) == 2:
+                k = np.tile(k, (hparams["n_head"], 1))
+                v = np.tile(v, (hparams["n_head"], 1))
+            elif len(k.shape) == 1:
+                k = np.tile(k, (hparams["n_head"]))
+                v = np.tile(v, (hparams["n_head"]))
+            # concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) ->
+            # (3 * n_heads * head_dim, hidden_dim)
+            data = np.concatenate((q, k, v), axis=0)
+
+        # header
+        str = name.encode('utf-8')
+        fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+        for i in range(n_dims):
+            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+        fout.write(str)
+
+        # data
+        data.tofile(fout)
+
+    fout.close()
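The most interesting step in the new converter is the MQA-to-MHA rewrite of `c_attn`: StarCoder stores a single shared K/V head, so the converter splits the fused QKV tensor and tiles K and V across all heads before writing it out. A self-contained numpy sketch of that shape transformation on dummy data (toy sizes, not tied to any real checkpoint):

    import numpy as np

    # Toy hyper-parameters standing in for hparams["n_embd"] / hparams["n_head"].
    n_embd, n_head = 8, 4
    head_dim = n_embd // n_head          # 2

    # MQA layout of c_attn.weight: ((n_head + 2) * head_dim, n_embd),
    # i.e. all query heads followed by ONE shared K head and ONE shared V head.
    data = np.arange((n_head + 2) * head_dim * n_embd, dtype=np.float32)
    data = data.reshape(((n_head + 2) * head_dim, n_embd))

    # Split into Q, shared K, shared V exactly as the converter does.
    q, k, v = np.split(data, (n_head * head_dim, (n_head + 1) * head_dim), axis=0)

    # Duplicate the shared K/V head across every query head (MQA -> MHA).
    k = np.tile(k, (n_head, 1))
    v = np.tile(v, (n_head, 1))

    # Re-fuse into the MHA layout ggml expects: (3 * n_head * head_dim, n_embd).
    mha = np.concatenate((q, k, v), axis=0)
    assert mha.shape == (3 * n_head * head_dim, n_embd)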