From 548f7a6cf72112fc44d1346a01de9c0851ab0097 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Fri, 18 Aug 2023 09:30:35 +0800
Subject: [PATCH] LLM: update convert of llama family to support llama2-70B
 (#8747)

---
 python/llm/src/bigdl/llm/ggml/convert.py      |  10 +-
 .../llm/src/bigdl/llm/utils/convert_util.py   | 333 ++++++++++++++----
 2 files changed, 265 insertions(+), 78 deletions(-)

diff --git a/python/llm/src/bigdl/llm/ggml/convert.py b/python/llm/src/bigdl/llm/ggml/convert.py
index a332caf1..0ab511ba 100644
--- a/python/llm/src/bigdl/llm/ggml/convert.py
+++ b/python/llm/src/bigdl/llm/ggml/convert.py
@@ -54,14 +54,14 @@ def _convert_llama(model_path, outfile_dir, outtype):
         vocab = model_plus.vocab
     else:
         vocab_dir = model_plus.paths[0].parent
-        vocab = load_vocab(vocab_dir)
+        vocab = load_vocab(vocab_dir, vocabtype='spm')
+    params = Params.load(model_plus)
     model = model_plus.model
-    model = do_necessary_conversions(model)
+    model = do_necessary_conversions(model, params)
     output_type = pick_output_type(model, outtype)
     model = convert_to_output_type(model, output_type)
-    params = Params.guessed(model, output_type)
-    outfile_path = default_outfile(outfile_dir, params)
-    OutputFile.write_all(outfile_path, params, model, vocab)
+    outfile_path = default_outfile([outfile_dir], output_type)
+    OutputFile.write_all(outfile_path, params, output_type, model, vocab)
 
 
 def _convert_gptneox(model_path, outfile_dir, outtype):
diff --git a/python/llm/src/bigdl/llm/utils/convert_util.py b/python/llm/src/bigdl/llm/utils/convert_util.py
index 8acfd541..bd223587 100644
--- a/python/llm/src/bigdl/llm/utils/convert_util.py
+++ b/python/llm/src/bigdl/llm/utils/convert_util.py
@@ -188,39 +188,141 @@ TENSORS_LIST = make_tensors_list()
 TENSORS_SET = set(TENSORS_LIST)
 
 
+def find_n_mult(n_ff: int, n_embd: int) -> int:
+    # hardcoded magic range
+    for n_mult in range(8192, 1, -1):
+        calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
+        if calc_ff == n_ff:
+            return n_mult
+    invalidInputError(False,
+                      f"Failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
+
+
 @dataclass
 class Params:
-    n_vocab: int
-    n_embd: int
-    n_mult: int
-    n_head: int
-    n_layer: int
-    file_type: GGMLFileType
+    n_vocab:   int
+    n_embd:    int
+    n_mult:    int
+    n_head:    int
+    n_layer:   int
+    n_kv_head: Optional[int] = None  # This parameter is only used for Llama 2
 
     @staticmethod
-    def guessed(model: 'LazyModel', file_type: GGMLFileType) -> 'Params':
-        n_vocab, n_embd = model["tok_embeddings.weight"].shape
+    def guessed(model: 'LazyModel') -> 'Params':
+        # try transformer naming first
+        if "model.embed_tokens.weight" in model:
+            n_vocab, n_embd = model["model.embed_tokens.weight"].shape
+        else:
+            n_vocab, n_embd = model["tok_embeddings.weight"].shape
+
+        # try transformer naming first
+        if "model.layers.0.self_attn.q_proj.weight" in model:
+            n_layer = next(i for i in itertools.count()
+                           if f"model.layers.{i}.self_attn.q_proj.weight" not in model)
+        elif "model.layers.0.self_attn.W_pack.weight" in model:  # next: try baichuan naming
+            n_layer = next(i for i in itertools.count()
+                           if f"model.layers.{i}.self_attn.W_pack.weight" not in model)
+        else:
+            n_layer = next(i for i in itertools.count()
+                           if f"layers.{i}.attention.wq.weight" not in model)
+
+        if n_layer < 1:
+            invalidInputError(False, "Failed to guess 'n_layer'. This model is unknown or "
+                              "unsupported.\nSuggestion: provide 'config.json' of the "
+                              "model in the same directory containing model files.")
+
+        n_head = n_embd // 128  # guessed
 
         return Params(
             n_vocab=n_vocab,
             n_embd=n_embd,
             n_mult=256,
-            n_head=n_embd // 128,
-            n_layer=next(i for i in itertools.count()
-                         if f"layers.{i}.attention.wq.weight" not in model),
-            file_type=file_type,
+            n_head=n_head,
+            n_layer=n_layer,
+            n_kv_head=None,
         )
 
+    @staticmethod
+    def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
+        config = json.load(open(config_path))
+
+        n_vocab = config["vocab_size"]
+        n_embd = config["hidden_size"]
+        n_head = config["num_attention_heads"]
+        n_layer = config["num_hidden_layers"]
+        n_ff = config["intermediate_size"]
+        n_kv_head = config.get("num_key_value_heads")
+
+        n_mult = find_n_mult(n_ff, n_embd)
+
+        return Params(
+            n_vocab=n_vocab,
+            n_embd=n_embd,
+            n_mult=n_mult,
+            n_head=n_head,
+            n_layer=n_layer,
+            n_kv_head=n_kv_head,
+        )
+
+    # LLaMA v2 70B params.json
+    # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8,
+    #  "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}
+    @staticmethod
+    def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
+        config = json.load(open(config_path))
+
+        n_vocab = config["vocab_size"]
+        n_embd = config["dim"]
+        n_head = config["n_heads"]
+        n_layer = config["n_layers"]
+        n_mult = config["multiple_of"]
+
+        if n_vocab == -1:
+            n_vocab = model["tok_embeddings.weight"].shape[0]
+
+        return Params(
+            n_vocab=n_vocab,
+            n_embd=n_embd,
+            n_mult=n_mult,
+            n_head=n_head,
+            n_layer=n_layer,
+            n_kv_head=None,
+        )
+
+    @staticmethod
+    def load(model_plus: 'ModelPlus') -> 'Params':
+        hf_config_path = model_plus.paths[0].parent / "config.json"
+        orig_config_path = model_plus.paths[0].parent / "params.json"
+
+        if hf_config_path.exists():
+            params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
+        elif orig_config_path.exists():
+            params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
+        else:
+            params = Params.guessed(model_plus.model)
+
+        print(f'params: n_vocab:{params.n_vocab} n_embd:{params.n_embd}'
+              f' n_mult:{params.n_mult} n_head:{params.n_head} n_layer:{params.n_layer}')
+        return params
+
 
 class SentencePieceVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
-        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path],
+                 vocabtype: Optional[str]) -> None:
+        self.vocabtype = vocabtype
+        if self.vocabtype == "bpe":
+            self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read())
+        else:
+            self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
         added_tokens: Dict[str, int]
         if fname_added_tokens is not None:
             added_tokens = json.load(open(fname_added_tokens))
         else:
             added_tokens = {}
-        vocab_size = self.sentencepiece_tokenizer.vocab_size()
+        if self.vocabtype == "bpe":
+            vocab_size: int = len(self.sentencepiece_tokenizer)
+        else:
+            vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
         expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
         actual_ids = sorted(added_tokens.values())
         invalidInputError(expected_ids == actual_ids,
@@ -235,22 +337,33 @@ class SentencePieceVocab:
 
     def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
         tokenizer = self.sentencepiece_tokenizer
-        for i in range(tokenizer.vocab_size()):
-            text: bytes
-            if tokenizer.is_unknown(i):
-                text = " \u2047 ".encode("utf-8")
-            elif tokenizer.is_control(i):
-                text = b""
-            elif tokenizer.is_byte(i):
-                piece = tokenizer.id_to_piece(i)
-                invalidInputError(len(piece) == 6,
-                                  f"Invalid token: {piece}")
-                byte_value = int(piece[3:-1], 16)
-                text = struct.pack("B", byte_value)
-            else:
-                text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
-            score = tokenizer.get_score(i)
-            yield text, score
+        if self.vocabtype == "bpe":
+            from transformers.models.gpt2 import tokenization_gpt2
+            byte_encoder = tokenization_gpt2.bytes_to_unicode()
+            byte_decoder = {v: k for k, v in byte_encoder.items()}
+            for i, item in enumerate(tokenizer):
+                text: bytes
+                text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y]
+                                 for y in item]])
+                score: float = -i
+                yield text, score
+        else:
+            for i in range(tokenizer.vocab_size()):
+                text: bytes
+                if tokenizer.is_unknown(i):
+                    text = " \u2047 ".encode("utf-8")
+                elif tokenizer.is_control(i):
+                    text = b""
+                elif tokenizer.is_byte(i):
+                    piece = tokenizer.id_to_piece(i)
+                    if len(piece) != 6:
+                        invalidInputError(False, f"Invalid token: {piece}")
+                    byte_value = int(piece[3:-1], 16)
+                    text = struct.pack("B", byte_value)
+                else:
+                    text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
+                score: float = tokenizer.get_score(i)
+                yield text, score
 
     def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
         for text in self.added_tokens_list:
@@ -281,7 +394,9 @@ class GGMLVocab:
 Vocab = Union[SentencePieceVocab, GGMLVocab]
 
 
-def permute(weights: NDArray, n_head: int) -> NDArray:
+def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
+    if n_kv_head is not None and n_head != n_kv_head:
+        n_head //= n_kv_head
     return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
             .swapaxes(1, 2)
             .reshape(weights.shape))
@@ -338,7 +453,15 @@ class Tensor(metaclass=ABCMeta):
         pass
 
     @abstractmethod
-    def permute(self, n_head: int) -> 'Tensor':
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'Tensor':
+        pass
+
+    @abstractmethod
+    def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor':
+        pass
+
+    @abstractmethod
+    def part(self, n_part: int) -> 'UnquantizedTensor':
         pass
 
     @abstractmethod
@@ -367,8 +490,16 @@ class UnquantizedTensor(Tensor):
     def to_ggml(self) -> 'UnquantizedTensor':
         return self
 
-    def permute(self, n_head: int) -> 'UnquantizedTensor':
-        return UnquantizedTensor(permute(self.ndarray, n_head))
+    def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor':
+        r = self.ndarray.shape[0] // 3
+        return UnquantizedTensor(permute(self.ndarray[r * n_part: r * n_part + r, ...], n_head))
+
+    def part(self, n_part: int) -> 'UnquantizedTensor':
+        r = self.ndarray.shape[0] // 3
+        return UnquantizedTensor(self.ndarray[r * n_part: r * n_part + r, ...])
+
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'UnquantizedTensor':
+        return UnquantizedTensor(permute(self.ndarray, n_head, n_kv_head))
 
 
 def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None,
@@ -417,26 +548,36 @@ class GGMLQuantizedTensor(Tensor):
     def to_ggml(self) -> 'GGMLQuantizedTensor':
         return self
 
-    def permute(self, n_head: int) -> 'GGMLQuantizedTensor':
-        return GGMLQuantizedTensor(permute(self.ndarray, n_head), self.shape, self.data_type)
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'GGMLQuantizedTensor':
+        return GGMLQuantizedTensor(permute(self.ndarray, n_head, n_kv_head),
+                                   self.shape,
+                                   self.data_type)
+
+    def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor':
+        r = self.ndarray.shape[0] // 3
+        return UnquantizedTensor(permute(self.ndarray[r * n_part: r * n_part + r, ...], n_head))
+
+    def part(self, n_part: int) -> 'UnquantizedTensor':
+        r = self.ndarray.shape[0] // 3
+        return UnquantizedTensor(self.ndarray[r * n_part: r * n_part + r, ...])
 
 
 GGMLCompatibleTensor = Union[UnquantizedTensor, GGMLQuantizedTensor]
 
 
 class DeferredPermutedTensor(Tensor):
-    def __init__(self, base: Tensor, n_head: int) -> None:
+    def __init__(self, base: Tensor, n_head: int, n_kv_head: Optional[int] = None) -> None:
         self.base = base
         self.n_head = n_head
+        self.n_kv_head = n_kv_head
         self.data_type = self.base.data_type
 
     def astype(self, data_type: DataType) -> Tensor:
-        return self.base.astype(data_type).permute(self.n_head)
+        return self.base.astype(data_type).permute(self.n_head, self.n_kv_head)
 
     def to_ggml(self) -> GGMLCompatibleTensor:
-        return self.base.to_ggml().permute(self.n_head)
+        return self.base.to_ggml().permute(self.n_head, self.n_kv_head)
 
-    def permute(self, n_head: int) -> Tensor:
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
         invalidInputError(False,
                           "Shouldn't permute twice.")
 
@@ -540,8 +681,8 @@ class GPTQForLLaMaQuantizedTensor(Tensor):
                    have_g_idx=False)
         return ret
 
-    def permute(self, n_head: int) -> Tensor:
-        return DeferredPermutedTensor(self, n_head)
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
+        return DeferredPermutedTensor(self, n_head, n_kv_head)
 
     def to_ggml(self) -> GGMLQuantizedTensor:
         # The output format looks like this:
@@ -598,8 +739,11 @@ class LazyTensor:
                           "Can't turn an unquantized tensor into"
                           f" a quantized type ({data_type}).")
         if self.data_type.have_g_idx:
-            sys.stderr.write("Error: Input uses the newer GPTQ-for-LLaMa format (using g_idx)"
-                             ", which is not yet natively supported by GGML.")
+            sys.stderr.write(
+                "Error: Input uses the newer GPTQ-for-LLaMa format (using g_idx), "
+                "which is not yet natively supported by GGML. For now "
+                "you can still convert this model by passing `--outtype f16` to dequantize, "
+                "but that will result in a much larger output file for no quality benefit.\n")
             sys.exit(1)
         invalidInputError(not data_type.have_g_idx and
                           self.data_type.have_addends and data_type.have_addends,
@@ -675,28 +819,57 @@ def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus:
     return ModelPlus(model, paths, format, vocab)
 
 
-def permute_lazy(lazy_tensor: LazyTensor, n_head: int) -> LazyTensor:
+def permute_lazy(lazy_tensor: LazyTensor, n_head: int,
+                 n_kv_head: Optional[int] = None) -> LazyTensor:
     def load() -> Tensor:
-        return lazy_tensor.load().permute(n_head)
+        return lazy_tensor.load().permute(n_head, n_kv_head)
     return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type,
+                      f'permute({n_head}, {n_kv_head}) ' + lazy_tensor.description)
+
+
+def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
+    def load() -> Tensor:
+        return lazy_tensor.load().permute_part(n_part, n_head)
+    s = lazy_tensor.shape.copy()
+    s[0] = s[0] // 3
+    return LazyTensor(load, s, lazy_tensor.data_type,
                       f'permute({n_head}) ' + lazy_tensor.description)
 
 
-def convert_transformers_to_orig(model: LazyModel) -> LazyModel:
+def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
+    def load() -> Tensor:
+        return lazy_tensor.load().part(n_part)
+    s = lazy_tensor.shape.copy()
+    s[0] = s[0] // 3
+    return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)
+
+
+def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
     out = {}
     out["tok_embeddings.weight"] = model["model.embed_tokens.weight"]
     out["norm.weight"] = model["model.norm.weight"]
     out["output.weight"] = model["lm_head.weight"]
 
-    n_head = model["model.layers.0.self_attn.q_proj.weight"].shape[1] // 128
     for i in itertools.count():
-        if f"model.layers.{i}.self_attn.q_proj.weight" not in model:
+        if f"model.layers.{i}.self_attn.q_proj.weight" in model:
+            out[f"layers.{i}.attention.wq.weight"] = \
+                permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
+            out[f"layers.{i}.attention.wk.weight"] = \
+                permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"],
+                             params.n_head, params.n_kv_head)
+            out[f"layers.{i}.attention.wv.weight"] = \
+                model[f"model.layers.{i}.self_attn.v_proj.weight"]
+        elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
+            out[f"layers.{i}.attention.wq.weight"] = \
+                permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"],
+                                  0, params.n_head)
+            out[f"layers.{i}.attention.wk.weight"] = \
+                permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"],
+                                  1, params.n_head)
+            out[f"layers.{i}.attention.wv.weight"] = \
+                part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
+        else:
             break
-        out[f"layers.{i}.attention.wq.weight"] = \
-            permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], n_head)
-        out[f"layers.{i}.attention.wk.weight"] = \
-            permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], n_head)
-        out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
         out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]
 
         out[f"layers.{i}.feed_forward.w1.weight"] = model[f"model.layers.{i}.mlp.gate_proj.weight"]
@@ -792,7 +965,9 @@ class LazyUnpickler(pickle.Unpickler):
                           f' path={self.zip_file.filename}'
         return LazyStorage(load=load, kind=pid[1], description=description)
 
+    # @staticmethod
     def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any,
+                               # pyright: ignore[reportSelfClsParameterName]
                                requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor:
         invalidInputError(isinstance(storage, LazyStorage),
                           "Fail to rebuild `LazyTensor`.")
 
@@ -837,6 +1012,7 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
 
 
 SAFETENSORS_DATA_TYPES = {
+    'BF16': DT_BF16,
     'F16': DT_F16,
     'F32': DT_F32,
     'I32': DT_I32,
@@ -1020,7 +1196,7 @@ class OutputFile:
     def __init__(self, fname_out: Path) -> None:
         self.fout = open(fname_out, "wb")
 
-    def write_file_header(self, params: Params) -> None:
+    def write_file_header(self, params: Params, file_type: GGMLFileType) -> None:
         self.fout.write(b"ggjt"[::-1])  # magic
         values = [
             1,  # file version
@@ -1030,7 +1206,7 @@ class OutputFile:
             params.n_head,
             params.n_layer,
             params.n_embd // params.n_head,  # rot (obsolete)
-            params.file_type.value,
+            file_type.value,
         ]
         self.fout.write(struct.pack("i" * len(values), *values))
 
@@ -1050,18 +1226,18 @@ class OutputFile:
     @staticmethod
     def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0,
-                        n_head=1, n_layer=0, file_type=GGMLFileType.AllF32)
+        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
         of = OutputFile(fname_out)
-        of.write_file_header(params)
+        of.write_file_header(params, file_type=GGMLFileType.AllF32)
         of.write_vocab(vocab)
         of.fout.close()
 
     @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def write_all(fname_out: Path, params: Params, file_type: GGMLFileType, model: LazyModel,
+                  vocab: Vocab) -> None:
         check_vocab_size(params, vocab)
         of = OutputFile(fname_out)
-        of.write_file_header(params)
+        of.write_file_header(params, file_type)
         print("Writing vocab...")
         of.write_vocab(vocab)
 
@@ -1099,11 +1275,11 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi
     invalidInputError(False,
                       f"Unexpected combination of types: {name_to_type}.")
 
 
-def do_necessary_conversions(model: LazyModel) -> LazyModel:
+def do_necessary_conversions(model: LazyModel, params: Params) -> LazyModel:
     model = handle_quantization(model)
 
     if "lm_head.weight" in model:
-        model = convert_transformers_to_orig(model)
+        model = convert_transformers_to_orig(model, params)
     model = filter_and_sort_tensors(model)
 
     return model
 
@@ -1157,7 +1333,7 @@ def load_some_model(path: Path) -> ModelPlus:
     '''Load a model of any supported format.'''
     # Be extra-friendly and accept either a file or a directory:
     if path.is_dir():
-        globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
+        globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
         files = [file for glob in globs for file in path.glob(glob)]
         if not files:
             # Try GGML too, but with lower priority, since if both a non-GGML
@@ -1183,35 +1359,46 @@ def filter_and_sort_tensors(model: LazyModel) -> LazyModel:
     return {name: model[name] for name in TENSORS_LIST if name in model}
 
 
-def load_vocab(path: Path) -> SentencePieceVocab:
+def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
     # Be extra-friendly and accept either a file or a directory.  Also, if it's
     # a directory, it might be the model directory, and tokenizer.model might
     # be in the parent of that.
+ print(f"vocabtype: {vocabtype}") if path.is_dir(): - path2 = path / "tokenizer.model" + vocab_file = "tokenizer.model" + if vocabtype == 'bpe': + vocab_file = "vocab.json" + path2 = path / vocab_file # Use `.parent` instead of /.. to handle the symlink case better. - path3 = path.parent / "tokenizer.model" + path3 = path.parent / vocab_file if path2.exists(): path = path2 elif path3.exists(): path = path3 else: invalidInputError(False, - f"Could not find tokenizer.model in {path} or its parent.") + f"Could not find tokenizer.model in {path} or its parent; " + "if it's in another directory, pass the directory as --vocab-dir") added_tokens_path = path.parent / "added_tokens.json" print(f"Loading vocab file {path}") - return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None) + return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, + vocabtype) -def default_outfile(output_dir: Path, params: Params) -> Path: +def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path: namestr = { GGMLFileType.AllF32: "f32", GGMLFileType.MostlyF16: "f16", GGMLFileType.MostlyQ4_0: "q4_0", GGMLFileType.MostlyQ4_1: "q4_1", GGMLFileType.PerLayerIsQ4_1: "q4_1", - }[params.file_type] - ret = output_dir / f"ggml-model-{namestr}.bin" + }[file_type] + ret = model_paths[0] / f"ggml-model-{namestr}.bin" + if ret in model_paths: + sys.stderr.write( + f"Error: Default output path ({ret}) would overwrite the input. " + "Please explicitly specify a path using --outfile.\n") + sys.exit(1) return ret