support loading llama tokenizer from gguf model (#9565)

This commit is contained in:
Yishuo Wang 2023-11-30 14:56:12 +08:00 committed by GitHub
parent 2554ba0913
commit 7f6465518a
4 changed files with 49 additions and 8 deletions

View file

@ -41,8 +41,8 @@ def load_gguf_model(fpath: str, dtype: torch.dtype = torch.float):
if model_family == "llama": if model_family == "llama":
from .models.llama import load_gguf_llama from .models.llama import load_gguf_llama
model = load_gguf_llama(loader, dtype) model, tokenizer = load_gguf_llama(loader, dtype)
else: else:
invalidInputError(False, f"Unsupported model family: {model_family}") invalidInputError(False, f"Unsupported model family: {model_family}")
return model, low_bit return model, tokenizer, low_bit

View file

@ -372,3 +372,25 @@ class GGUFFileLoader:
def tensors_iter(self): def tensors_iter(self):
return self.tensor_loader return self.tensor_loader
def tokenizer_pieces(self):
from transformers.convert_slow_tokenizer import import_protobuf
spm_pb2 = import_protobuf("Failed to import protobuf")
tokens = self.config['tokenizer.ggml.tokens']
scores = self.config['tokenizer.ggml.scores']
token_types = self.config['tokenizer.ggml.token_type']
pieces = [
spm_pb2.ModelProto.SentencePiece(
piece=token,
score=score,
type=token_type,
)
for token, score, token_type in tqdm(
zip(tokens, scores, token_types),
"Loading gguf vocab"
)
]
return pieces

View file

@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
# #
import os
import torch import torch
from accelerate import init_empty_weights from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device from accelerate.utils import set_module_tensor_to_device
from transformers import LlamaConfig, LlamaForCausalLM from tempfile import NamedTemporaryFile
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
from ..gguf import GGUFFileLoader from ..gguf import GGUFFileLoader
@ -79,12 +81,29 @@ def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
model = model.cpu() model = model.cpu()
return model # see https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
from transformers.convert_slow_tokenizer import import_protobuf
spm_pb2 = import_protobuf("Failed to import protobuf")
pieces = loader.tokenizer_pieces()
trainer_spec = spm_pb2.TrainerSpec(byte_fallback=True,
model_type=spm_pb2.TrainerSpec.ModelType.BPE)
proto = spm_pb2.ModelProto(pieces=pieces, trainer_spec=trainer_spec)
proto = proto.SerializeToString()
with NamedTemporaryFile(delete=False) as f:
f.write(proto)
f.close()
tokenizer = LlamaTokenizer(f.name)
os.remove(f.name)
return model, tokenizer
def restore_llama_weight(ckpt: dict, n_head: int, n_head_kv: int): def restore_llama_weight(ckpt: dict, n_head: int, n_head_kv: int):
# see https://github.com/ggerganov/llama.cpp/blob # see https://github.com/ggerganov/llama.cpp/blob
# /3e73d31d9cc0232882ce61c64742aff3ecfec416/convert.py#L978 # /3e73d31d9cc0232882ce61c64742aff3ecfec416/convert.py#L978
for name, weight in ckpt.items(): for name, weight in ckpt.items():
head, hd_size = weight.shape[0], weight.shape[1:] head, hd_size = weight.shape[0], weight.shape[1:]
if name.endswith("attn_q.weight"): if name.endswith("attn_q.weight"):

View file

@ -194,21 +194,21 @@ class _BaseAutoModelClass:
@staticmethod @staticmethod
def from_gguf(fpath: str, optimize_model: bool = True, cpu_embedding: bool = False): def from_gguf(fpath: str, optimize_model: bool = True, cpu_embedding: bool = False):
""" """
Load a gguf model and convert it to bigdl-llm model Load gguf model and tokenizer and convert it to bigdl-llm model and huggingface tokenzier
:param fpath: Path to gguf model file :param fpath: Path to gguf model file
:param optimize_model: Whether to further optimize llm model, defaults to True :param optimize_model: Whether to further optimize llm model, defaults to True
:param cpu_embedding: Whether to replace the Embedding layer, may need to set it :param cpu_embedding: Whether to replace the Embedding layer, may need to set it
to `True` when running BigDL-LLM on GPU on Windows, defaults to False to `True` when running BigDL-LLM on GPU on Windows, defaults to False
:return: An optimized bigdl-llm model :return: An optimized bigdl-llm model and a huggingface tokenizer
""" """
from bigdl.llm.optimize import optimize_model as optimize_model_fn from bigdl.llm.optimize import optimize_model as optimize_model_fn
model, low_bit = load_gguf_model(fpath, dtype=torch.half) model, tokenizer, low_bit = load_gguf_model(fpath, dtype=torch.half)
model = optimize_model_fn(model, low_bit=low_bit, optimize_llm=optimize_model, model = optimize_model_fn(model, low_bit=low_bit, optimize_llm=optimize_model,
cpu_embedding=cpu_embedding) cpu_embedding=cpu_embedding)
return model return model, tokenizer
@classmethod @classmethod
def load_convert(cls, q_k, optimize_model, *args, **kwargs): def load_convert(cls, q_k, optimize_model, *args, **kwargs):