diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md index 87a6cdd0..15654003 100644 --- a/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md +++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md @@ -7,6 +7,7 @@ In this directory, you will find examples on how to load GGUF model into `bigdl- - [Mixtral-8x7B-v0.1-GGUF](https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF) - [Baichuan2-7B-Chat-GGUF](https://huggingface.co/second-state/Baichuan2-7B-Chat-GGUF/tree/main) - [Bloomz-7b1-GGUF](https://huggingface.co/hzjane/bloomz-7b1-gguf) +- [falcon-7b-quantized-gguf](https://huggingface.co/xaviviro/falcon-7b-quantized-gguf/tree/main) - [mpt-7b-chat-gguf](https://huggingface.co/maddes8cht/mosaicml-mpt-7b-chat-gguf/tree/main) ## Requirements diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md index 368f5667..8740ac2c 100644 --- a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md +++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md @@ -7,6 +7,7 @@ In this directory, you will find examples on how to load GGUF model into `bigdl- - [Mixtral-8x7B-v0.1-GGUF](https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF) - [Baichuan2-7B-Chat-GGUF](https://huggingface.co/second-state/Baichuan2-7B-Chat-GGUF/tree/main) - [Bloomz-7b1-GGUF](https://huggingface.co/hzjane/bloomz-7b1-gguf) +- [falcon-7b-quantized-gguf](https://huggingface.co/xaviviro/falcon-7b-quantized-gguf/tree/main) - [mpt-7b-chat-gguf](https://huggingface.co/maddes8cht/mosaicml-mpt-7b-chat-gguf/tree/main) ## Requirements diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index a270905b..8f704ed6 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -523,32 +523,33 @@ def _optimize_post(model, lightweight_bmm=False): bloom_attention_forward ) elif "falcon" in model.config.model_type or "RefinedWeb" in model.config.model_type: - modeling_module_name = model.__class__.__module__ - module = importlib.import_module(modeling_module_name) - if "RWForCausalLM" in model.config.architectures: - if model.config.hidden_size == 4544: - # falcon-7b need to check performance drop after kv cache support. 
- # from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b - # convert_forward(model, - # module.Attention, - # rw_attention_forward_7b - # ) - pass - else: - # falcon-40b - from bigdl.llm.transformers.models.falcon import rw_attention_forward_40b - convert_forward(model, - module.Attention, - rw_attention_forward_40b - ) - elif "FalconForCausalLM" in model.config.architectures: - if model.config.hidden_size != 4544: - # falcon-180b and new falcon-40b - from bigdl.llm.transformers.models.falcon import falcon_attention_forward - convert_forward(model, - module.FalconAttention, - falcon_attention_forward - ) + if model.config.architectures is not None: + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + if "RWForCausalLM" in model.config.architectures: + if model.config.hidden_size == 4544: + # falcon-7b need to check performance drop after kv cache support. + # from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b + # convert_forward(model, + # module.Attention, + # rw_attention_forward_7b + # ) + pass + else: + # falcon-40b + from bigdl.llm.transformers.models.falcon import rw_attention_forward_40b + convert_forward(model, + module.Attention, + rw_attention_forward_40b + ) + elif "FalconForCausalLM" in model.config.architectures: + if model.config.hidden_size != 4544: + # falcon-180b and new falcon-40b + from bigdl.llm.transformers.models.falcon import falcon_attention_forward + convert_forward(model, + module.FalconAttention, + falcon_attention_forward + ) elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696: # baichuan2 if model.config.hidden_size == 4096: diff --git a/python/llm/src/bigdl/llm/transformers/gguf/api.py b/python/llm/src/bigdl/llm/transformers/gguf/api.py index 08c29165..e9e5fa29 100644 --- a/python/llm/src/bigdl/llm/transformers/gguf/api.py +++ b/python/llm/src/bigdl/llm/transformers/gguf/api.py @@ -57,6 +57,9 @@ def load_gguf_model(fpath: str, dtype: torch.dtype = torch.float): elif model_family == "bloom": from .models.bloom import load_gguf_bloom model, tokenizer = load_gguf_bloom(loader, dtype) + elif model_family == "falcon": + from .models.falcon import load_gguf_falcon + model, tokenizer = load_gguf_falcon(loader, dtype) elif model_family == "mpt": from .models.mpt import load_gguf_mpt model, tokenizer = load_gguf_mpt(loader, dtype) diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/falcon.py b/python/llm/src/bigdl/llm/transformers/gguf/models/falcon.py new file mode 100644 index 00000000..d7d708db --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/gguf/models/falcon.py @@ -0,0 +1,112 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import torch +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from transformers import FalconConfig, FalconForCausalLM, PreTrainedTokenizerFast + +from ..gguf import GGUFFileLoader + + +def load_gguf_falcon(loader: GGUFFileLoader, dtype: torch.dtype = torch.float): + config = loader.config + + falcon_config = FalconConfig( + vocab_size=len(config['tokenizer.ggml.tokens']), + hidden_size=config['falcon.embedding_length'], + num_hidden_layers=config['falcon.block_count'], + num_attention_heads=config['falcon.attention.head_count'], + num_kv_heads=config['falcon.attention.head_count_kv'], + max_position_embeddings=config['falcon.context_length'], + layer_norm_epsilon=config['falcon.attention.layer_norm_epsilon'], + use_cache=True, + bos_token_id=config['tokenizer.ggml.bos_token_id'], + eos_token_id=config['tokenizer.ggml.eos_token_id'], + # architectures="FalconForCausalLM", + ) + + ckpt = loader.tensors(dtype) + n_head = config['falcon.attention.head_count'] + n_head_kv = config['falcon.attention.head_count_kv'] + head_dim = config['falcon.embedding_length'] // n_head + ckpt = restore_falcon_weight(ckpt, n_head, n_head_kv, head_dim) + + state_dict = {} + state_dict['transformer.word_embeddings.weight'] = ckpt['token_embd.weight'] + state_dict['transformer.ln_f.weight'] = ckpt['output_norm.weight'] + state_dict['transformer.ln_f.bias'] = ckpt['output_norm.bias'] + state_dict['lm_head.weight'] = ckpt['output.weight'] + for i in range(config['falcon.block_count']): + state_dict[f'transformer.h.{i}.self_attention.query_key_value.weight'] = \ + ckpt[f'blk.{i}.attn_qkv.weight'] + state_dict[f'transformer.h.{i}.self_attention.dense.weight'] = \ + ckpt[f'blk.{i}.attn_output.weight'] + state_dict[f'transformer.h.{i}.mlp.dense_h_to_4h.weight'] = \ + ckpt[f'blk.{i}.ffn_up.weight'] + state_dict[f'transformer.h.{i}.mlp.dense_4h_to_h.weight'] = \ + ckpt[f'blk.{i}.ffn_down.weight'] + state_dict[f'transformer.h.{i}.input_layernorm.weight'] = \ + ckpt[f'blk.{i}.attn_norm.weight'] + state_dict[f'transformer.h.{i}.input_layernorm.bias'] = \ + ckpt[f'blk.{i}.attn_norm.bias'] + + with init_empty_weights(): + model = FalconForCausalLM(falcon_config) + + for name, weight in state_dict.items(): + set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype) + + model = model.cpu() + + pieces, merges = loader.tokenizer_pieces() + + current_directory = os.path.dirname(os.path.abspath(__file__)) + token_file = current_directory + "/model_implement/falcon/tokenizer.json" + import json + with open(token_file, 'r') as file: + data = json.load(file) + vocab = {} + # load and replace vocab and merges + for i in range(len(pieces)): + token = pieces[i].piece + score = int(pieces[i].score) + vocab[token] = score + data['model']['merges'] = merges + data['model']['vocab'] = vocab + + with open(token_file, 'w') as file: + json.dump(data, file, indent=4) + tokenizer = PreTrainedTokenizerFast(tokenizer_file=token_file) + return model, tokenizer + + +def restore_falcon_weight(ckpt: dict, n_head: int, n_head_kv: int, head_dim: int): + # see https://github.com/ggerganov/llama.cpp/blob/ + # master/convert-hf-to-gguf.py#L666 + import numpy as np + for name, weight in ckpt.items(): + if name.endswith("attn_qkv.weight"): + part1, part2, part3 = np.split(weight.reshape(-1, head_dim * n_head), + [n_head * head_dim, (n_head + n_head_kv) * head_dim], + axis=0) + part1 = part1.reshape((n_head_kv, n_head // n_head_kv, head_dim, head_dim * n_head)) + part2 = 
part2.reshape((n_head_kv, 1, head_dim, head_dim * n_head)) + part3 = part3.reshape((n_head_kv, 1, head_dim, head_dim * n_head)) + data = torch.cat([part1, part2, part3], dim=1) + ckpt[name] = data.reshape(-1, head_dim * n_head) + return ckpt diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/model_implement/falcon/tokenizer.json b/python/llm/src/bigdl/llm/transformers/gguf/models/model_implement/falcon/tokenizer.json new file mode 100644 index 00000000..ee7c6ded --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/gguf/models/model_implement/falcon/tokenizer.json @@ -0,0 +1,160 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": ">>TITLE<<", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 1, + "content": ">>ABSTRACT<<", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 2, + "content": ">>INTRODUCTION<<", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 3, + "content": ">>SUMMARY<<", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 4, + "content": ">>COMMENT<<", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 5, + "content": ">>ANSWER<<", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 6, + "content": ">>QUESTION<<", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 7, + "content": ">>DOMAIN<<", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 8, + "content": ">>PREFIX<<", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 9, + "content": ">>SUFFIX<<", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 10, + "content": ">>MIDDLE<<", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 11, + "content": "<|endoftext|>", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Sequence", + "pretokenizers": [ + { + "type": "Punctuation", + "behavior": "Contiguous" + }, + { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": true + }, + { + "type": "Digits", + "individual_digits": false + }, + { + "type": "Split", + "pattern": { + "Regex": "[0-9][0-9][0-9]" + }, + "behavior": "Isolated", + "invert": false + } + ] + }, + "post_processor": null, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true, + "use_regex": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "vocab": null, + "merges": null + } +} \ No newline at end of file
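
As a usage sketch (not part of the patch): the new `elif model_family == "falcon"` branch in `gguf/api.py` is reached whenever the GGUF metadata reports a falcon architecture, so a locally downloaded falcon GGUF file (for example one of the files from the `xaviviro/falcon-7b-quantized-gguf` repo added to the READMEs above) can be loaded through `load_gguf_model` directly. The checkpoint path and prompt below are placeholders, and the snippet assumes `load_gguf_model` returns the `(model, tokenizer)` pair produced by the falcon branch; the higher-level `AutoModelForCausalLM.from_gguf` helper shown in the GGUF example READMEs can be used instead.

```python
# Sketch only: exercise the new falcon branch of load_gguf_model().
# The .gguf path below is a placeholder for a locally downloaded checkpoint.
import torch
from bigdl.llm.transformers.gguf.api import load_gguf_model

gguf_path = "/path/to/falcon-7b-q4_0.gguf"  # hypothetical local path

# The GGUF metadata identifies the model family as "falcon", so this call
# dispatches to load_gguf_falcon(): it rebuilds a FalconForCausalLM from the
# GGUF tensors and a PreTrainedTokenizerFast from the GGUF vocab and merges.
model, tokenizer = load_gguf_model(gguf_path, dtype=torch.float)

prompt = "What is AI?"  # placeholder prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
with torch.inference_mode():
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```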
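
And a small self-contained sanity check of the fused-QKV reordering in `restore_falcon_weight` (a sketch that assumes this patch is installed; the tiny head counts are made up to keep the printout readable). GGUF stores the fused weight as all Q rows, then all K rows, then all V rows, while the `transformers` Falcon checkpoint expects the rows grouped per KV head as [Q…, K, V], which is what the reshape/concat in the helper produces.

```python
# Toy sanity check of restore_falcon_weight() (sketch, patch must be installed).
import torch
from bigdl.llm.transformers.gguf.models.falcon import restore_falcon_weight

n_head, n_head_kv, head_dim = 4, 2, 2          # made-up sizes, not real falcon dims
hidden = n_head * head_dim
rows = (n_head + 2 * n_head_kv) * head_dim     # Q rows + K rows + V rows

# Tag every row of the fused QKV weight with its index so the reordering is
# visible: rows 0-7 are the Q heads, rows 8-11 the K heads, rows 12-15 the V heads.
weight = torch.arange(rows, dtype=torch.float).unsqueeze(1).expand(rows, hidden).clone()

ckpt = {"blk.0.attn_qkv.weight": weight}
out = restore_falcon_weight(ckpt, n_head, n_head_kv, head_dim)["blk.0.attn_qkv.weight"]

# First source row of each head after restoration; expected per-KV-group order
# [Q0, Q1, K0, V0, Q2, Q3, K1, V1] -> [0.0, 2.0, 8.0, 12.0, 4.0, 6.0, 10.0, 14.0]
print(out[:, 0].view(-1, head_dim)[:, 0].tolist())
```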