LLM: convert and quantize support for StarCoder (#8359)
* basic support for starcoder
* update from_pretrained
* fix bug and fix style
parent 5f4f399ca7
commit f99d348954

5 changed files with 206 additions and 19 deletions
@@ -72,6 +72,10 @@ def _convert_bloom(model_path, outfile_dir, outtype):
     _convert_bloom_hf_to_ggml(model_path, outfile_dir, outtype)
 
 
+def _convert_starcoder(model_path, outfile_dir, outtype):
+    _convert_starcoder_hf_to_ggml(model_path, outfile_dir, outtype)
+
+
 def _convert_to_ggml(model_path: str, outfile_dir: str,
                      model_family: str = 'llama', outtype: str="fp16"):
     """
@@ -84,12 +88,12 @@ def _convert_to_ggml(model_path: str, outfile_dir: str,
             For lora finetuned model, the path should be pointed to a merged weight.
     :param outfile_dir: str, the directory to save ggml compatible file, for example `./models`.
     :param model_family: Which model family your input model belongs to. Default to `llama`.
-            Now only `llama`/`bloom`/`gptneox` are supported.
+            Now only `llama`/`bloom`/`gptneox`/`starcoder` are supported.
     :param outtype: specify the output format. Defalut to `fp16`. Now `fp32`/`fp16` are supported.
     """
-    invalidInputError(model_family in ['llama', 'bloom', 'gptneox'],
+    invalidInputError(model_family in ['llama', 'bloom', 'gptneox', 'starcoder'],
                       "Now we only support quantization of model \
-                       family('llama', 'bloom', 'gptneox')",
+                       family('llama', 'bloom', 'gptneox', 'starcoder')",
                       "{} is not in the list.".format(model_family))
     invalidInputError(os.path.exists(model_path),
                       "The file {} was not found".format(model_path))
@@ -108,3 +112,5 @@ def _convert_to_ggml(model_path: str, outfile_dir: str,
         _convert_gptneox(model_path, outfile_dir, outtype)
     if model_family == 'bloom':
         _convert_bloom(model_path, outfile_dir, outtype)
+    if model_family == 'starcoder':
+        _convert_starcoder(model_path, outfile_dir, outtype)
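For reference, a hedged sketch of how this new dispatch is reached. `_convert_to_ggml` is the helper whose docstring is updated above; its import path is not part of this diff, and the checkpoint path below is a placeholder:

# Example (not part of the diff): produce an fp16 ggml file from a local StarCoder checkpoint.
# _convert_to_ggml is the internal helper shown above; import it from wherever it lives in bigdl.llm.
_convert_to_ggml(model_path="./starcoder-checkpoint",   # placeholder local HF checkpoint
                 outfile_dir="./models",
                 model_family="starcoder",
                 outtype="fp16")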
@@ -39,7 +39,7 @@ def convert_model(input_path: str,
     :param output_path: Save path of output quantized model. You must pass a *directory* to
             save all related output.
     :param model_family: Which model family your input model belongs to.
-            Now only `llama`/`bloom`/`gptneox` are supported.
+            Now only `llama`/`bloom`/`gptneox`/`starcoder` are supported.
     :param dtype: Which quantized precision will be converted.
             Now only `int4` and `int8` are supported, and `int8` only works for `llama`
             and `gptneox`.
@@ -53,9 +53,9 @@ def convert_model(input_path: str,
     # make sure directory exists
     os.makedirs(output_path, exist_ok=True)
     # check input value
-    invalidInputError(model_family in ['llama', 'bloom', 'gptneox'],
+    invalidInputError(model_family in ['llama', 'bloom', 'gptneox', 'starcoder'],
                       "Now we only support quantization of model \
-                       family('llama', 'bloom', 'gptneox')",
+                       family('llama', 'bloom', 'gptneox', 'starcoder')",
                       "{} is not in the list.".format(model_family))
     invalidInputError(os.path.isdir(output_path),
                       "The output_path {} was not a directory".format(output_path))
@@ -72,9 +72,9 @@ def convert_model(input_path: str,
         dtype = 'q4_0'
     elif dtype == 'int8':
         dtype = 'q8_0'
-        invalidInputError(model_family in ['llama', 'gptneox'],
+        invalidInputError(model_family in ['llama', 'gptneox', 'starcoder'],
                           "Now we only support int8 quantization of model \
-                          family('llama', 'gptneox')",
+                          family('llama', 'gptneox', 'starcoder')",
                           "{} is not in the list.".format(model_family))
 
     if tmp_path is not None:
@@ -110,7 +110,7 @@ def main():
                         help=("output_path,save path of output quantized model."))
     parser.add_argument('-x', '--model_family', type=str, required=True,
                         help=("model_family: Which model family your input model belongs to."
-                              "Now only `llama`/`bloom`/`gptneox` are supported."))
+                              "Now only `llama`/`bloom`/`gptneox`/`starcoder` are supported."))
     parser.add_argument('-t', '--dtype', type=str, default="int4",
                         help="Which quantized precision will be converted.")
     parser.add_argument('-p', '--tmp_path', type=str, default=None,
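These flags feed the public `convert_model` helper whose docstring is updated above. A hedged Python equivalent (the helper's import path is not visible in this diff, and the paths are placeholders):

# Example (not part of the diff): convert + quantize a StarCoder checkpoint in one call.
# convert_model is the function documented above; import it from its module in bigdl.llm
# (exact path not shown here).
convert_model(input_path="./starcoder-checkpoint",  # placeholder path to merged HF weights
              output_path="./models",               # must be a directory
              model_family="starcoder",
              dtype="int4")                         # 'int8' is now also accepted for starcoder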
@@ -36,10 +36,16 @@ _gptneox_quantize_type = {"q4_0": 2,
                           "q5_0": 8,
                           "q5_1": 9,
                           "q8_0": 7}
+_starcoder_quantize_type = {"q4_0": 2,
+                            "q4_1": 3,
+                            "q5_0": 8,
+                            "q5_1": 9,
+                            "q8_0": 7}
 
 _quantize_type = {"llama": _llama_quantize_type,
                   "bloom": _bloom_quantize_type,
-                  "gptneox": _gptneox_quantize_type}
+                  "gptneox": _gptneox_quantize_type,
+                  "starcoder": _starcoder_quantize_type}
 
 
 def quantize(input_path: str, output_path: str,
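The new table reuses the same ggml ftype codes as the gptneox one, so the lookup performed by `quantize` stays uniform across families; a tiny sketch:

# Example (not part of the diff): dtype string -> ggml ftype code for starcoder.
ftype_code = _quantize_type["starcoder"]["q8_0"]   # -> 7, identical to the gptneox mapping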
@@ -52,19 +58,20 @@ def quantize(input_path: str, output_path: str,
             save all related output. Filename of quantized model will be like
             `bigdl_llm_llama_q4_0.bin`.
     :param model_family: Which model family your input model belongs to.
-            Now only `llama`/`bloom`/`gptneox` are supported.
+            Now only `llama`/`bloom`/`gptneox`/`starcoder` are supported.
     :param dtype: Quantization method which differs in the resulting model disk size and
             inference speed. Defalut to `q4_0`. Difference model family may support
             different types, now the supported list is:
             llama : "q4_0", "q4_1", "q4_2"
             bloom : "q4_0", "q4_1"
             gptneox : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
+            starcoder : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
 
     :return: the path str to the converted ggml binary checkpoint
     """
-    invalidInputError(model_family in ['llama', 'bloom', 'gptneox'],
+    invalidInputError(model_family in ['llama', 'bloom', 'gptneox', 'starcoder'],
                       "Now we only support quantization of model \
-                       family('llama', 'bloom', 'gptneox')",
+                       family('llama', 'bloom', 'gptneox', 'starcoder')",
                       "{} is not in the list.".format(model_family))
     invalidInputError(os.path.isfile(input_path),
                       "The file {} was not found".format(input_path))
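A hedged usage sketch of `quantize` on the ggml f16 file produced by conversion (the file name and import path are assumptions, not taken from this diff):

# Example (not part of the diff): quantize a converted StarCoder ggml file to q4_0.
out_bin = quantize(input_path="./models/ggml-starcoder-f16.bin",  # placeholder f16 ggml file
                   output_path="./models",
                   model_family="starcoder",
                   dtype="q4_0")
# per the docstring above, quantize returns the path of the quantized binary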
@@ -50,10 +50,10 @@ class AutoModelForCausalLM:
                3. a str for huggingface hub repo id.
 
         :param model_family: the model family of the pretrained checkpoint.
-               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"``.
+               Currently we support ``"llama"``, ``"bloom"``, ``"gptneox"`` and ``"starcoder"``.
         :param dtype: Which quantized precision will be converted.
                 Now only `int4` and `int8` are supported, and `int8` only works for `llama`
-                and `gptneox`.
+                , `gptneox` and `starcoder`.
         :param cache_dir: (optional) this parameter will only be used when
                ``pretrained_model_name_or_path`` is a hugginface checkpoint or hub repo id.
                It indicates the saving path for the converted low precision model.
@@ -63,9 +63,9 @@ class AutoModelForCausalLM:
 
         :return: a model instance
         """
-        invalidInputError(model_family in ['llama', 'gptneox', 'bloom'],
-                          "Now we only support model family: 'llama', 'gptneox', 'bloom', "
-                          "'{}' is not in the list.".format(model_family))
+        invalidInputError(model_family in ['llama', 'gptneox', 'bloom', 'starcoder'],
+                          "Now we only support model family: 'llama', 'gptneox', 'bloom',"
+                          " 'starcoder', '{}' is not in the list.".format(model_family))
         invalidInputError(dtype.lower() in ['int4', 'int8'],
                           "Now we only support int4 and int8 as date type for weight")
 
@@ -110,3 +110,6 @@ class AutoModelForCausalLM:
         elif model_family == 'bloom':
             from bigdl.llm.ggml.model.bloom import Bloom
             return Bloom(model_path=ggml_model_path, **kwargs)
+        elif model_family == 'starcoder':
+            from bigdl.llm.ggml.model.starcoder import Starcoder
+            return Starcoder(model_path=ggml_model_path, **kwargs)
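Putting the pieces together, a hedged end-to-end sketch of the updated `from_pretrained` path (the class's import location is not part of this diff; the hub repo id is illustrative only):

# Example (not part of the diff): load StarCoder as an int4 ggml model.
# AutoModelForCausalLM here is bigdl-llm's wrapper shown above, not transformers';
# import it from its bigdl.llm module (path not visible in this diff).
llm = AutoModelForCausalLM.from_pretrained("bigcode/starcoder",   # illustrative hub repo id
                                           model_family="starcoder",
                                           dtype="int4",
                                           cache_dir="./converted")

The returned object is the new `Starcoder` ggml wrapper, used like the existing Llama/Bloom/Gptneox ones.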
@@ -87,7 +87,8 @@ __all__ = ['Params',
            'load_vocab',
            'default_outfile',
            '_convert_gptneox_hf_to_ggml',
-           '_convert_bloom_hf_to_ggml']
+           '_convert_bloom_hf_to_ggml',
+           '_convert_starcoder_hf_to_ggml']
 
 
 @dataclass(frozen=True)
@@ -1415,3 +1416,173 @@ def _convert_bloom_hf_to_ggml(model_path, outfile_dir, outtype):
         data.tofile(fout)
 
     fout.close()
+
+
+def _convert_starcoder_hf_to_ggml(model_path, outfile_dir, outtype):
+    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+    import torch
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    hparams = config.to_dict()
+    model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
+                                                 torch_dtype=torch.float16
+                                                 if outtype == "f16" else torch.float32,
+                                                 low_cpu_mem_usage=True,
+                                                 trust_remote_code=True,
+                                                 offload_state_dict=True)
+
+    list_vars = model.state_dict()
+
+    encoder = tokenizer.vocab
+    # Add added_tokens (special tokens) to the encoder
+    encoder.update(tokenizer.get_added_vocab())
+
+    filestem = Path(model_path).stem
+    fn_out = os.path.join(outfile_dir, f"ggml-{filestem}-{outtype}.bin")
+    fout = open(fn_out, "wb")
+
+    if outtype == "f16":
+        ftype = 1
+    else:
+        ftype = 0
+
+    fout.write(struct.pack("i", 0x67676d6c))  # magic: ggml in hex
+    vocab_size = hparams["vocab_size"]
+    fout.write(struct.pack("i", vocab_size))
+    # fout.write(struct.pack("i", len(encoder)))
+    fout.write(struct.pack("i", hparams["n_positions"]))
+    fout.write(struct.pack("i", hparams["n_embd"]))
+    fout.write(struct.pack("i", hparams["n_head"]))
+    fout.write(struct.pack("i", hparams["n_layer"]))
+    fout.write(struct.pack("i", ftype))
+
+    byte_encoder = bytes_to_unicode()
+    byte_decoder = {v: k for k, v in byte_encoder.items()}
+
+    fout.write(struct.pack("i", vocab_size))
+
+    counter = 0
+    # sort by value
+    for key in sorted(encoder, key=encoder.get):
+        text = bytearray([byte_decoder[c] for c in key])
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+        counter += 1
+
+    # TODO: Repeat last token until vocab_size
+    while counter < vocab_size:
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+        counter += 1
+
+    for name in list_vars.keys():
+        data = list_vars[name].squeeze().numpy()
+        print("Processing variable: " + name + " with shape: ", data.shape)
+
+        # rename headers to keep compatibility
+        if name == "transformer.ln_f.weight":
+            name = "model/ln_f/g"
+        elif name == "transformer.ln_f.bias":
+            name = "model/ln_f/b"
+        elif name == "transformer.wte.weight":
+            name = "model/wte"
+        elif name == "transformer.wpe.weight":
+            name = "model/wpe"
+        elif name == "lm_head.weight":
+            name = "model/lm_head"
+        elif re.match(r"transformer.h\.\d+\.ln_1\.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/ln_1/g"
+        elif re.match(r"transformer.h\.\d+\.ln_1\.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/ln_1/b"
+        elif re.match(r"transformer.h\.\d+\.attn\.c_attn\.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/attn/c_attn/w"
+        elif re.match(r"transformer.h\.\d+\.attn\.c_attn\.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/attn/c_attn/b"
+        elif re.match(r"transformer.h\.\d+\.attn\.c_proj\.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/attn/c_proj/w"
+        elif re.match(r"transformer.h.\d+.attn.c_proj.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/attn/c_proj/b"
+        elif re.match(r"transformer.h.\d+.ln_2.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/ln_2/g"
+        elif re.match(r"transformer.h.\d+.ln_2.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/ln_2/b"
+        elif re.match(r"transformer.h.\d+.mlp.c_fc.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/mlp/c_fc/w"
+        elif re.match(r"transformer.h.\d+.mlp.c_fc.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/mlp/c_fc/b"
+        elif re.match(r"transformer.h.\d+.mlp.c_proj.weight", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/mlp/c_proj/w"
+        elif re.match(r"transformer.h.\d+.mlp.c_proj.bias", name):
+            i = re.findall("\d+", name)[0]
+            name = f"model/h{i}/mlp/c_proj/b"
+        else:
+            print("Unrecognized variable name. %s", name)
+
+        # we don't need these
+        if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"):
+            print("  Skipping variable: " + name)
+            continue
+
+        n_dims = len(data.shape)
+
+        ftype_cur = 0
+        if ftype == 1:
+            if (name == "model/wte" or name == "model/lm_head" or name[-2:] == "/g" or
+                    name[-2:] == "/w") and n_dims == 2:
+                print("  Converting to float16")
+                data = data.astype(np.float16)
+                ftype_cur = 1
+            else:
+                print("  Converting to float32")
+                data = data.astype(np.float32)
+                ftype_cur = 0
+
+        "model/h.*/attn/c_attn/w"
+        "model/h.*/attn/c_proj/w"
+        "model/h.*/mlp/c_fc/w"
+        "model/h.*/mlp/c_proj/w"
+        if name[-14:] == "/attn/c_attn/w" or name[-14:] == "/attn/c_attn/b":
+            print("  Duplicate K,V heads to use MHA instead of MQA")
+
+            embed_dim = hparams["n_embd"]
+            head_dim = embed_dim // hparams["n_head"]
+
+            # ((n_heads + 2) * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
+            q, k, v = np.split(data,
+                               (hparams["n_head"] * head_dim,
+                                (hparams["n_head"] + 1) * head_dim),
+                               axis=0)
+            # duplicate k, v along the first axis (head_dim, hidden_dim) ->
+            # (n_heads * head_dim, hidden_dim)
+            if len(k.shape) == 2:
+                k = np.tile(k, (hparams["n_head"], 1))
+                v = np.tile(v, (hparams["n_head"], 1))
+            elif len(k.shape) == 1:
+                k = np.tile(k, (hparams["n_head"]))
+                v = np.tile(v, (hparams["n_head"]))
+            # concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) ->
+            # (3 * n_heads * head_dim, hidden_dim)
+            data = np.concatenate((q, k, v), axis=0)
+
+        # header
+        str = name.encode('utf-8')
+        fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+        for i in range(n_dims):
+            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+        fout.write(str)
+
+        # data
+        data.tofile(fout)
+
+    fout.close()
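The `c_attn` handling above widens StarCoder's multi-query attention (one shared K/V head) into a standard multi-head layout before the tensor is written. A self-contained sketch of the shape arithmetic with toy sizes, not tied to any real checkpoint:

# Example (not part of the diff): MQA -> MHA duplication with toy dimensions.
import numpy as np

n_head, head_dim, hidden = 4, 8, 32                        # toy sizes; n_embd = n_head * head_dim
c_attn_w = np.zeros(((n_head + 2) * head_dim, hidden))     # MQA packs n_head Q heads + 1 K + 1 V

# same split points as in the conversion code above
q, k, v = np.split(c_attn_w, (n_head * head_dim, (n_head + 1) * head_dim), axis=0)
k = np.tile(k, (n_head, 1))                                # replicate the single K head for every Q head
v = np.tile(v, (n_head, 1))
mha_w = np.concatenate((q, k, v), axis=0)
assert mha_w.shape == (3 * n_head * head_dim, hidden)      # standard fused QKV layout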