LLM: Add gguf falcon (#9801)

* init falcon

* update convert.py

* update style
This commit is contained in:
Wang, Jian4 2024-01-03 14:49:02 +08:00 committed by GitHub
parent 0396fafed1
commit a54cd767b1
6 changed files with 304 additions and 26 deletions

View file

@ -7,6 +7,7 @@ In this directory, you will find examples on how to load GGUF model into `bigdl-
- [Mixtral-8x7B-v0.1-GGUF](https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF)
- [Baichuan2-7B-Chat-GGUF](https://huggingface.co/second-state/Baichuan2-7B-Chat-GGUF/tree/main)
- [Bloomz-7b1-GGUF](https://huggingface.co/hzjane/bloomz-7b1-gguf)
- [falcon-7b-quantized-gguf](https://huggingface.co/xaviviro/falcon-7b-quantized-gguf/tree/main)
- [mpt-7b-chat-gguf](https://huggingface.co/maddes8cht/mosaicml-mpt-7b-chat-gguf/tree/main)
## Requirements

View file

@ -7,6 +7,7 @@ In this directory, you will find examples on how to load GGUF model into `bigdl-
- [Mixtral-8x7B-v0.1-GGUF](https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF)
- [Baichuan2-7B-Chat-GGUF](https://huggingface.co/second-state/Baichuan2-7B-Chat-GGUF/tree/main)
- [Bloomz-7b1-GGUF](https://huggingface.co/hzjane/bloomz-7b1-gguf)
- [falcon-7b-quantized-gguf](https://huggingface.co/xaviviro/falcon-7b-quantized-gguf/tree/main)
- [mpt-7b-chat-gguf](https://huggingface.co/maddes8cht/mosaicml-mpt-7b-chat-gguf/tree/main)
## Requirements

View file

@ -523,32 +523,33 @@ def _optimize_post(model, lightweight_bmm=False):
bloom_attention_forward
)
elif "falcon" in model.config.model_type or "RefinedWeb" in model.config.model_type:
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
if "RWForCausalLM" in model.config.architectures:
if model.config.hidden_size == 4544:
# falcon-7b need to check performance drop after kv cache support.
# from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
# convert_forward(model,
# module.Attention,
# rw_attention_forward_7b
# )
pass
else:
# falcon-40b
from bigdl.llm.transformers.models.falcon import rw_attention_forward_40b
convert_forward(model,
module.Attention,
rw_attention_forward_40b
)
elif "FalconForCausalLM" in model.config.architectures:
if model.config.hidden_size != 4544:
# falcon-180b and new falcon-40b
from bigdl.llm.transformers.models.falcon import falcon_attention_forward
convert_forward(model,
module.FalconAttention,
falcon_attention_forward
)
if model.config.architectures is not None:
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
if "RWForCausalLM" in model.config.architectures:
if model.config.hidden_size == 4544:
# falcon-7b need to check performance drop after kv cache support.
# from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
# convert_forward(model,
# module.Attention,
# rw_attention_forward_7b
# )
pass
else:
# falcon-40b
from bigdl.llm.transformers.models.falcon import rw_attention_forward_40b
convert_forward(model,
module.Attention,
rw_attention_forward_40b
)
elif "FalconForCausalLM" in model.config.architectures:
if model.config.hidden_size != 4544:
# falcon-180b and new falcon-40b
from bigdl.llm.transformers.models.falcon import falcon_attention_forward
convert_forward(model,
module.FalconAttention,
falcon_attention_forward
)
elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
# baichuan2
if model.config.hidden_size == 4096:

View file

@ -57,6 +57,9 @@ def load_gguf_model(fpath: str, dtype: torch.dtype = torch.float):
elif model_family == "bloom":
from .models.bloom import load_gguf_bloom
model, tokenizer = load_gguf_bloom(loader, dtype)
elif model_family == "falcon":
from .models.falcon import load_gguf_falcon
model, tokenizer = load_gguf_falcon(loader, dtype)
elif model_family == "mpt":
from .models.mpt import load_gguf_mpt
model, tokenizer = load_gguf_mpt(loader, dtype)

View file

@ -0,0 +1,112 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from transformers import FalconConfig, FalconForCausalLM, PreTrainedTokenizerFast
from ..gguf import GGUFFileLoader
def load_gguf_falcon(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
config = loader.config
falcon_config = FalconConfig(
vocab_size=len(config['tokenizer.ggml.tokens']),
hidden_size=config['falcon.embedding_length'],
num_hidden_layers=config['falcon.block_count'],
num_attention_heads=config['falcon.attention.head_count'],
num_kv_heads=config['falcon.attention.head_count_kv'],
max_position_embeddings=config['falcon.context_length'],
layer_norm_epsilon=config['falcon.attention.layer_norm_epsilon'],
use_cache=True,
bos_token_id=config['tokenizer.ggml.bos_token_id'],
eos_token_id=config['tokenizer.ggml.eos_token_id'],
# architectures="FalconForCausalLM",
)
ckpt = loader.tensors(dtype)
n_head = config['falcon.attention.head_count']
n_head_kv = config['falcon.attention.head_count_kv']
head_dim = config['falcon.embedding_length'] // n_head
ckpt = restore_falcon_weight(ckpt, n_head, n_head_kv, head_dim)
state_dict = {}
state_dict['transformer.word_embeddings.weight'] = ckpt['token_embd.weight']
state_dict['transformer.ln_f.weight'] = ckpt['output_norm.weight']
state_dict['transformer.ln_f.bias'] = ckpt['output_norm.bias']
state_dict['lm_head.weight'] = ckpt['output.weight']
for i in range(config['falcon.block_count']):
state_dict[f'transformer.h.{i}.self_attention.query_key_value.weight'] = \
ckpt[f'blk.{i}.attn_qkv.weight']
state_dict[f'transformer.h.{i}.self_attention.dense.weight'] = \
ckpt[f'blk.{i}.attn_output.weight']
state_dict[f'transformer.h.{i}.mlp.dense_h_to_4h.weight'] = \
ckpt[f'blk.{i}.ffn_up.weight']
state_dict[f'transformer.h.{i}.mlp.dense_4h_to_h.weight'] = \
ckpt[f'blk.{i}.ffn_down.weight']
state_dict[f'transformer.h.{i}.input_layernorm.weight'] = \
ckpt[f'blk.{i}.attn_norm.weight']
state_dict[f'transformer.h.{i}.input_layernorm.bias'] = \
ckpt[f'blk.{i}.attn_norm.bias']
with init_empty_weights():
model = FalconForCausalLM(falcon_config)
for name, weight in state_dict.items():
set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
model = model.cpu()
pieces, merges = loader.tokenizer_pieces()
current_directory = os.path.dirname(os.path.abspath(__file__))
token_file = current_directory + "/model_implement/falcon/tokenizer.json"
import json
with open(token_file, 'r') as file:
data = json.load(file)
vocab = {}
# load and replace vocab and merges
for i in range(len(pieces)):
token = pieces[i].piece
score = int(pieces[i].score)
vocab[token] = score
data['model']['merges'] = merges
data['model']['vocab'] = vocab
with open(token_file, 'w') as file:
json.dump(data, file, indent=4)
tokenizer = PreTrainedTokenizerFast(tokenizer_file=token_file)
return model, tokenizer
def restore_falcon_weight(ckpt: dict, n_head: int, n_head_kv: int, head_dim: int):
# see https://github.com/ggerganov/llama.cpp/blob/
# master/convert-hf-to-gguf.py#L666
import numpy as np
for name, weight in ckpt.items():
if name.endswith("attn_qkv.weight"):
part1, part2, part3 = np.split(weight.reshape(-1, head_dim * n_head),
[n_head * head_dim, (n_head + n_head_kv) * head_dim],
axis=0)
part1 = part1.reshape((n_head_kv, n_head // n_head_kv, head_dim, head_dim * n_head))
part2 = part2.reshape((n_head_kv, 1, head_dim, head_dim * n_head))
part3 = part3.reshape((n_head_kv, 1, head_dim, head_dim * n_head))
data = torch.cat([part1, part2, part3], dim=1)
ckpt[name] = data.reshape(-1, head_dim * n_head)
return ckpt

View file

@ -0,0 +1,160 @@
{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [
{
"id": 0,
"content": ">>TITLE<<",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 1,
"content": ">>ABSTRACT<<",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 2,
"content": ">>INTRODUCTION<<",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 3,
"content": ">>SUMMARY<<",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 4,
"content": ">>COMMENT<<",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 5,
"content": ">>ANSWER<<",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 6,
"content": ">>QUESTION<<",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 7,
"content": ">>DOMAIN<<",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 8,
"content": ">>PREFIX<<",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 9,
"content": ">>SUFFIX<<",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 10,
"content": ">>MIDDLE<<",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 11,
"content": "<|endoftext|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
],
"normalizer": null,
"pre_tokenizer": {
"type": "Sequence",
"pretokenizers": [
{
"type": "Punctuation",
"behavior": "Contiguous"
},
{
"type": "ByteLevel",
"add_prefix_space": false,
"trim_offsets": true,
"use_regex": true
},
{
"type": "Digits",
"individual_digits": false
},
{
"type": "Split",
"pattern": {
"Regex": "[0-9][0-9][0-9]"
},
"behavior": "Isolated",
"invert": false
}
]
},
"post_processor": null,
"decoder": {
"type": "ByteLevel",
"add_prefix_space": true,
"trim_offsets": true,
"use_regex": true
},
"model": {
"type": "BPE",
"dropout": null,
"unk_token": null,
"continuing_subword_prefix": null,
"end_of_word_suffix": null,
"fuse_unk": false,
"vocab": null,
"merges": null
}
}