diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md
index 13e6079f..87a6cdd0 100644
--- a/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md
@@ -7,6 +7,7 @@ In this directory, you will find examples on how to load GGUF model into `bigdl-
 - [Mixtral-8x7B-v0.1-GGUF](https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF)
 - [Baichuan2-7B-Chat-GGUF](https://huggingface.co/second-state/Baichuan2-7B-Chat-GGUF/tree/main)
 - [Bloomz-7b1-GGUF](https://huggingface.co/hzjane/bloomz-7b1-gguf)
+- [mpt-7b-chat-gguf](https://huggingface.co/maddes8cht/mosaicml-mpt-7b-chat-gguf/tree/main)
 
 ## Requirements
 To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../../../README.md#system-support) for more information.
diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md
index 26346b90..368f5667 100644
--- a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md
+++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GGUF/README.md
@@ -7,6 +7,7 @@ In this directory, you will find examples on how to load GGUF model into `bigdl-
 - [Mixtral-8x7B-v0.1-GGUF](https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF)
 - [Baichuan2-7B-Chat-GGUF](https://huggingface.co/second-state/Baichuan2-7B-Chat-GGUF/tree/main)
 - [Bloomz-7b1-GGUF](https://huggingface.co/hzjane/bloomz-7b1-gguf)
+- [mpt-7b-chat-gguf](https://huggingface.co/maddes8cht/mosaicml-mpt-7b-chat-gguf/tree/main)
 
 ## Requirements
 To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../../../README.md#system-support) for more information.
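Note: the READMEs above document example scripts that load such a checkpoint through
the high-level from_gguf entry point. A minimal sketch of that flow, assuming the same
API as the repository's other GGUF examples (the local .gguf file name is a placeholder):

    from bigdl.llm.transformers import AutoModelForCausalLM

    # Load the GGUF file directly; the model family and tokenizer are derived
    # from the file's metadata. The path below is hypothetical.
    model, tokenizer = AutoModelForCausalLM.from_gguf("mosaicml-mpt-7b-chat-Q4_0.gguf")
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))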
diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 44f98379..aed54457 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -502,14 +502,15 @@ def _optimize_post(model, lightweight_bmm=False):
                         chatglm_attention_forward
                         )
     elif "mpt" in model.config.model_type:
-        modeling_module_name = model.__class__.__module__
-        attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
-        module = importlib.import_module(attention_module_name)
-        from bigdl.llm.transformers.models.mpt import mpt_multihead_attention_forward
-        convert_forward(model,
-                        module.MultiheadAttention,
-                        mpt_multihead_attention_forward
-                        )
+        if model.config.architectures is not None:
+            modeling_module_name = model.__class__.__module__
+            attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
+            module = importlib.import_module(attention_module_name)
+            from bigdl.llm.transformers.models.mpt import mpt_multihead_attention_forward
+            convert_forward(model,
+                            module.MultiheadAttention,
+                            mpt_multihead_attention_forward
+                            )
     elif "gptj" in model.config.model_type:
         # dolly-v1-6b
         modeling_module_name = model.__class__.__module__
diff --git a/python/llm/src/bigdl/llm/transformers/gguf/api.py b/python/llm/src/bigdl/llm/transformers/gguf/api.py
index b7a5e644..08c29165 100644
--- a/python/llm/src/bigdl/llm/transformers/gguf/api.py
+++ b/python/llm/src/bigdl/llm/transformers/gguf/api.py
@@ -57,6 +57,9 @@ def load_gguf_model(fpath: str, dtype: torch.dtype = torch.float):
     elif model_family == "bloom":
         from .models.bloom import load_gguf_bloom
         model, tokenizer = load_gguf_bloom(loader, dtype)
+    elif model_family == "mpt":
+        from .models.mpt import load_gguf_mpt
+        model, tokenizer = load_gguf_mpt(loader, dtype)
     else:
         invalidInputError(False,
                           f"Unsupported model family: {model_family}")
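Note: the new "mpt" branch above is selected from the GGUF file's architecture metadata,
so callers need no extra flag. A sketch of invoking the dispatcher directly, assuming
load_gguf_model returns the (model, tokenizer) pair each branch constructs (the file
name is a placeholder):

    import torch
    from bigdl.llm.transformers.gguf.api import load_gguf_model

    # dtype controls the torch dtype the GGUF tensors are loaded as
    model, tokenizer = load_gguf_model("mosaicml-mpt-7b-chat-Q4_0.gguf",
                                       dtype=torch.half)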
diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/model_implement/mpt/tokenizer.json b/python/llm/src/bigdl/llm/transformers/gguf/models/model_implement/mpt/tokenizer.json
new file mode 100644
index 00000000..7c40dff6
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/gguf/models/model_implement/mpt/tokenizer.json
@@ -0,0 +1,282 @@
+{
+  "version": "1.0",
+  "truncation": null,
+  "padding": null,
+  "added_tokens": [
+    {
+      "id": 0,
+      "content": "<|endoftext|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 1,
+      "content": "<|padding|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 50254,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50255,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50256,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50257,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50258,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50259,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50260,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50261,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50262,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50263,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50264,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50265,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50266,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50267,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50268,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50269,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50270,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50271,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50272,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50273,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50274,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50275,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50276,
+      "content": " ",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": false
+    },
+    {
+      "id": 50277,
+      "content": "<|im_start|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 50278,
+      "content": "<|im_end|>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    }
+  ],
+  "normalizer": {
+    "type": "NFC"
+  },
+  "pre_tokenizer": {
+    "type": "ByteLevel",
+    "add_prefix_space": false,
+    "trim_offsets": true,
+    "use_regex": true
+  },
+  "post_processor": {
+    "type": "ByteLevel",
+    "add_prefix_space": false,
+    "trim_offsets": true,
+    "use_regex": true
+  },
+  "decoder": {
+    "type": "ByteLevel",
+    "add_prefix_space": false,
+    "trim_offsets": true,
+    "use_regex": true
+  },
+  "model": {
+    "type": "BPE",
+    "dropout": null,
+    "unk_token": null,
+    "continuing_subword_prefix": null,
+    "end_of_word_suffix": null,
+    "fuse_unk": false,
+    "byte_fallback": false,
+    "vocab": null,
+    "merges": null
+  }
+}
\ No newline at end of file
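Note: the template above describes an NFC-normalized, byte-level BPE pipeline (the
GPT-NeoX tokenizer family) and deliberately leaves "vocab" and "merges" null; the loader
below fills them in from the GGUF file at runtime. A self-contained sketch of the same
pipeline, with toy stand-ins for the vocab and merges:

    from tokenizers import Tokenizer, models, normalizers, pre_tokenizers, decoders

    # Toy vocab/merges; the real ones come from the GGUF tokenizer metadata.
    tok = Tokenizer(models.BPE(vocab={"h": 0, "i": 1, "hi": 2}, merges=[("h", "i")]))
    tok.normalizer = normalizers.NFC()
    tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tok.decoder = decoders.ByteLevel()
    print(tok.encode("hi").tokens)  # ['hi']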
diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/mpt.py b/python/llm/src/bigdl/llm/transformers/gguf/models/mpt.py
new file mode 100644
index 00000000..d0830d33
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/gguf/models/mpt.py
@@ -0,0 +1,88 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import torch
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from transformers import MptConfig, MptForCausalLM, GPTNeoXTokenizerFast
+
+from ..gguf import GGUFFileLoader
+
+
+def load_gguf_mpt(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
+    config = loader.config
+
+    mpt_config = MptConfig(
+        vocab_size=len(config['tokenizer.ggml.tokens']),
+        d_model=config['mpt.embedding_length'],
+        n_layers=config['mpt.block_count'],
+        n_heads=config['mpt.attention.head_count'],
+        max_position_embeddings=config['mpt.context_length'],
+        layer_norm_epsilon=config['mpt.attention.layer_norm_epsilon'],
+        bos_token_id=config['tokenizer.ggml.bos_token_id'],
+        eos_token_id=config['tokenizer.ggml.eos_token_id'],
+        unknown_token_id=config['tokenizer.ggml.unknown_token_id'],
+    )
+
+    ckpt = loader.tensors(dtype)
+
+    state_dict = {}
+    state_dict['transformer.wte.weight'] = ckpt['token_embd.weight']
+    state_dict['transformer.norm_f.weight'] = ckpt['output_norm.weight']
+    state_dict['lm_head.weight'] = ckpt['output.weight']
+    for i in range(config['mpt.block_count']):
+        state_dict[f'transformer.blocks.{i}.attn.Wqkv.weight'] = \
+            ckpt[f'blk.{i}.attn_qkv.weight']
+        state_dict[f'transformer.blocks.{i}.attn.out_proj.weight'] = \
+            ckpt[f'blk.{i}.attn_output.weight']
+        state_dict[f'transformer.blocks.{i}.norm_2.weight'] = \
+            ckpt[f'blk.{i}.ffn_norm.weight']
+        state_dict[f'transformer.blocks.{i}.ffn.up_proj.weight'] = \
+            ckpt[f'blk.{i}.ffn_up.weight']
+        state_dict[f'transformer.blocks.{i}.ffn.down_proj.weight'] = \
+            ckpt[f'blk.{i}.ffn_down.weight']
+        state_dict[f'transformer.blocks.{i}.norm_1.weight'] = \
+            ckpt[f'blk.{i}.attn_norm.weight']
+
+    with init_empty_weights():
+        model = MptForCausalLM(mpt_config)
+
+    for name, weight in state_dict.items():
+        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
+
+    model = model.cpu()
+
+    pieces, merges = loader.tokenizer_pieces()
+
+    current_directory = os.path.dirname(os.path.abspath(__file__))
+    token_file = current_directory + "/model_implement/mpt/tokenizer.json"
+    import json
+    with open(token_file, 'r') as file:
+        data = json.load(file)
+    vocab = {}
+    # load and replace vocab and merges
+    for i in range(len(pieces)):
+        token = pieces[i].piece
+        score = int(pieces[i].score)
+        vocab[token] = score
+    data['model']['merges'] = merges
+    data['model']['vocab'] = vocab
+
+    with open(token_file, 'w') as file:
+        json.dump(data, file, indent=4)
+    tokenizer = GPTNeoXTokenizerFast(tokenizer_file=token_file)
+    return model, tokenizer
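Note: for completeness, a hypothetical direct use of the new loader; the GGUFFileLoader
constructor is assumed to take the .gguf path, matching how the other load_gguf_*
loaders receive it, and the file name is a placeholder:

    import torch
    from bigdl.llm.transformers.gguf.gguf import GGUFFileLoader
    from bigdl.llm.transformers.gguf.models.mpt import load_gguf_mpt

    # Parse the GGUF file, then build the HF MptForCausalLM and its tokenizer.
    loader = GGUFFileLoader("mosaicml-mpt-7b-chat-Q4_0.gguf")  # hypothetical path
    model, tokenizer = load_gguf_mpt(loader, dtype=torch.float16)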