support loading llama tokenizer from gguf model (#9565)
commit 7f6465518a (parent 2554ba0913)
4 changed files with 49 additions and 8 deletions
@@ -41,8 +41,8 @@ def load_gguf_model(fpath: str, dtype: torch.dtype = torch.float):
     if model_family == "llama":
         from .models.llama import load_gguf_llama

-        model = load_gguf_llama(loader, dtype)
+        model, tokenizer = load_gguf_llama(loader, dtype)
     else:
         invalidInputError(False, f"Unsupported model family: {model_family}")

-    return model, low_bit
+    return model, tokenizer, low_bit
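For orientation, a minimal sketch of what a caller of load_gguf_model now receives; the import path and file name below are assumptions for illustration, not part of the diff:

import torch
# Import path assumed from the package layout visible in this diff.
from bigdl.llm.transformers.gguf.api import load_gguf_model

# The loader now returns the tokenizer alongside the model and the
# low-bit format detected from the gguf quantization type.
model, tokenizer, low_bit = load_gguf_model("llama-2-7b.Q4_0.gguf",
                                            dtype=torch.half)
print(low_bit)  # e.g. "sym_int4" for a Q4_0 file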
@@ -372,3 +372,25 @@ class GGUFFileLoader:

     def tensors_iter(self):
         return self.tensor_loader
+
+    def tokenizer_pieces(self):
+        from transformers.convert_slow_tokenizer import import_protobuf
+        spm_pb2 = import_protobuf("Failed to import protobuf")
+
+        tokens = self.config['tokenizer.ggml.tokens']
+        scores = self.config['tokenizer.ggml.scores']
+        token_types = self.config['tokenizer.ggml.token_type']
+
+        pieces = [
+            spm_pb2.ModelProto.SentencePiece(
+                piece=token,
+                score=score,
+                type=token_type,
+            )
+            for token, score, token_type in tqdm(
+                zip(tokens, scores, token_types),
+                "Loading gguf vocab"
+            )
+        ]
+
+        return pieces
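A hedged usage sketch of the new helper; the GGUFFileLoader import path and the file name are assumptions for illustration:

# Import path assumed; GGUFFileLoader is the class extended above.
from bigdl.llm.transformers.gguf.gguf import GGUFFileLoader

loader = GGUFFileLoader("llama-2-7b.Q4_0.gguf")  # hypothetical gguf file
pieces = loader.tokenizer_pieces()

# Each piece mirrors sentencepiece's ModelProto.SentencePiece message:
# a surface string, a log-probability score, and a token type
# (NORMAL, UNKNOWN, CONTROL, BYTE, ...).
print(pieces[0].piece, pieces[0].score, pieces[0].type)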
@@ -14,10 +14,12 @@
 # limitations under the License.
 #

+import os
 import torch
 from accelerate import init_empty_weights
 from accelerate.utils import set_module_tensor_to_device
-from transformers import LlamaConfig, LlamaForCausalLM
+from tempfile import NamedTemporaryFile
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

 from ..gguf import GGUFFileLoader
@@ -79,12 +81,29 @@ def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     model = model.cpu()

-    return model
+    # see https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
+    from transformers.convert_slow_tokenizer import import_protobuf
+    spm_pb2 = import_protobuf("Failed to import protobuf")
+
+    pieces = loader.tokenizer_pieces()
+    trainer_spec = spm_pb2.TrainerSpec(byte_fallback=True,
+                                       model_type=spm_pb2.TrainerSpec.ModelType.BPE)
+    proto = spm_pb2.ModelProto(pieces=pieces, trainer_spec=trainer_spec)
+    proto = proto.SerializeToString()
+
+    with NamedTemporaryFile(delete=False) as f:
+        f.write(proto)
+        f.close()
+    tokenizer = LlamaTokenizer(f.name)
+    os.remove(f.name)
+
+    return model, tokenizer


 def restore_llama_weight(ckpt: dict, n_head: int, n_head_kv: int):
     # see https://github.com/ggerganov/llama.cpp/blob
     # /3e73d31d9cc0232882ce61c64742aff3ecfec416/convert.py#L978

     for name, weight in ckpt.items():
         head, hd_size = weight.shape[0], weight.shape[1:]
         if name.endswith("attn_q.weight"):
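One detail worth noting in the tokenizer construction above: LlamaTokenizer is built from a file path, so the serialized sentencepiece proto has to be written to disk first. NamedTemporaryFile(delete=False) followed by an explicit os.remove is used because on Windows a named temporary file cannot be reopened by path while it is still open; closing it before handing the name to LlamaTokenizer and deleting it afterwards works on all platforms.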
@@ -194,21 +194,21 @@ class _BaseAutoModelClass:
     @staticmethod
     def from_gguf(fpath: str, optimize_model: bool = True, cpu_embedding: bool = False):
         """
-        Load a gguf model and convert it to bigdl-llm model
+        Load a gguf model and tokenizer and convert them to a bigdl-llm model and a huggingface tokenizer

         :param fpath: Path to gguf model file
         :param optimize_model: Whether to further optimize llm model, defaults to True
         :param cpu_embedding: Whether to replace the Embedding layer, may need to set it
             to `True` when running BigDL-LLM on GPU on Windows, defaults to False

-        :return: An optimized bigdl-llm model
+        :return: An optimized bigdl-llm model and a huggingface tokenizer
         """
         from bigdl.llm.optimize import optimize_model as optimize_model_fn

-        model, low_bit = load_gguf_model(fpath, dtype=torch.half)
+        model, tokenizer, low_bit = load_gguf_model(fpath, dtype=torch.half)
         model = optimize_model_fn(model, low_bit=low_bit, optimize_llm=optimize_model,
                                   cpu_embedding=cpu_embedding)
-        return model
+        return model, tokenizer

     @classmethod
     def load_convert(cls, q_k, optimize_model, *args, **kwargs):
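A minimal end-to-end sketch of the updated from_gguf API; the model file name and generation settings are hypothetical:

from bigdl.llm.transformers import AutoModelForCausalLM  # import path assumed

# from_gguf now returns both the optimized model and a huggingface tokenizer.
model, tokenizer = AutoModelForCausalLM.from_gguf("llama-2-7b-chat.Q4_0.gguf")

input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))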