support loading llama tokenizer from gguf model (#9565)
This commit is contained in:
parent
2554ba0913
commit
7f6465518a
4 changed files with 49 additions and 8 deletions
|
|
@ -41,8 +41,8 @@ def load_gguf_model(fpath: str, dtype: torch.dtype = torch.float):
|
|||
if model_family == "llama":
|
||||
from .models.llama import load_gguf_llama
|
||||
|
||||
model = load_gguf_llama(loader, dtype)
|
||||
model, tokenizer = load_gguf_llama(loader, dtype)
|
||||
else:
|
||||
invalidInputError(False, f"Unsupported model family: {model_family}")
|
||||
|
||||
return model, low_bit
|
||||
return model, tokenizer, low_bit
|
||||
|
|
|
|||
|
|
@ -372,3 +372,25 @@ class GGUFFileLoader:
|
|||
|
||||
def tensors_iter(self):
|
||||
return self.tensor_loader
|
||||
|
||||
def tokenizer_pieces(self):
|
||||
from transformers.convert_slow_tokenizer import import_protobuf
|
||||
spm_pb2 = import_protobuf("Failed to import protobuf")
|
||||
|
||||
tokens = self.config['tokenizer.ggml.tokens']
|
||||
scores = self.config['tokenizer.ggml.scores']
|
||||
token_types = self.config['tokenizer.ggml.token_type']
|
||||
|
||||
pieces = [
|
||||
spm_pb2.ModelProto.SentencePiece(
|
||||
piece=token,
|
||||
score=score,
|
||||
type=token_type,
|
||||
)
|
||||
for token, score, token_type in tqdm(
|
||||
zip(tokens, scores, token_types),
|
||||
"Loading gguf vocab"
|
||||
)
|
||||
]
|
||||
|
||||
return pieces
|
||||
|
|
|
|||
|
|
@ -14,10 +14,12 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
from tempfile import NamedTemporaryFile
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
from ..gguf import GGUFFileLoader
|
||||
|
||||
|
|
@ -79,12 +81,29 @@ def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
|
|||
|
||||
model = model.cpu()
|
||||
|
||||
return model
|
||||
# see https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
|
||||
from transformers.convert_slow_tokenizer import import_protobuf
|
||||
spm_pb2 = import_protobuf("Failed to import protobuf")
|
||||
|
||||
pieces = loader.tokenizer_pieces()
|
||||
trainer_spec = spm_pb2.TrainerSpec(byte_fallback=True,
|
||||
model_type=spm_pb2.TrainerSpec.ModelType.BPE)
|
||||
proto = spm_pb2.ModelProto(pieces=pieces, trainer_spec=trainer_spec)
|
||||
proto = proto.SerializeToString()
|
||||
|
||||
with NamedTemporaryFile(delete=False) as f:
|
||||
f.write(proto)
|
||||
f.close()
|
||||
tokenizer = LlamaTokenizer(f.name)
|
||||
os.remove(f.name)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def restore_llama_weight(ckpt: dict, n_head: int, n_head_kv: int):
|
||||
# see https://github.com/ggerganov/llama.cpp/blob
|
||||
# /3e73d31d9cc0232882ce61c64742aff3ecfec416/convert.py#L978
|
||||
|
||||
for name, weight in ckpt.items():
|
||||
head, hd_size = weight.shape[0], weight.shape[1:]
|
||||
if name.endswith("attn_q.weight"):
|
||||
|
|
|
|||
|
|
@ -194,21 +194,21 @@ class _BaseAutoModelClass:
|
|||
@staticmethod
|
||||
def from_gguf(fpath: str, optimize_model: bool = True, cpu_embedding: bool = False):
|
||||
"""
|
||||
Load a gguf model and convert it to bigdl-llm model
|
||||
Load gguf model and tokenizer and convert it to bigdl-llm model and huggingface tokenzier
|
||||
|
||||
:param fpath: Path to gguf model file
|
||||
:param optimize_model: Whether to further optimize llm model, defaults to True
|
||||
:param cpu_embedding: Whether to replace the Embedding layer, may need to set it
|
||||
to `True` when running BigDL-LLM on GPU on Windows, defaults to False
|
||||
|
||||
:return: An optimized bigdl-llm model
|
||||
:return: An optimized bigdl-llm model and a huggingface tokenizer
|
||||
"""
|
||||
from bigdl.llm.optimize import optimize_model as optimize_model_fn
|
||||
|
||||
model, low_bit = load_gguf_model(fpath, dtype=torch.half)
|
||||
model, tokenizer, low_bit = load_gguf_model(fpath, dtype=torch.half)
|
||||
model = optimize_model_fn(model, low_bit=low_bit, optimize_llm=optimize_model,
|
||||
cpu_embedding=cpu_embedding)
|
||||
return model
|
||||
return model, tokenizer
|
||||
|
||||
@classmethod
|
||||
def load_convert(cls, q_k, optimize_model, *args, **kwargs):
|
||||
|
|
|
|||
Loading…
Reference in a new issue