Load Mixtral GGUF Model (#9690)
* Load Mixtral GGUF Model
* refactor
* fix empty tensor when to cpu
* update gpu and cpu readmes
* add dtype when set tensor into module
parent d0a3095b97
commit 1fa7793fc0

4 changed files with 117 additions and 6 deletions
@@ -4,12 +4,13 @@ In this directory, you will find examples on how to load GGUF model into `bigdl-llm`
 ## Verified Models(Q4_0)
 - [Llama-2-7B-Chat-GGUF](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/tree/main)
 - [Mistral-7B-Instruct-v0.1-GGUF](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF)
+- [Mixtral-8x7B-v0.1-GGUF](https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF)
 - [Baichuan2-7B-Chat-GGUF](https://huggingface.co/second-state/Baichuan2-7B-Chat-GGUF/tree/main)
 
 ## Requirements
 To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../../../README.md#system-support) for more information.
 
-**Important: Please make sure you have installed `transformers==4.33.0` to run the example.**
+**Important: Please make sure you have installed `transformers==4.36.0` to run the example.**
 
 
 ## Example: Load gguf model using `from_gguf()` API
@@ -4,12 +4,13 @@ In this directory, you will find examples on how to load GGUF model into `bigdl-llm`
 ## Verified Models(Q4_0)
 - [Llama-2-7B-Chat-GGUF](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/tree/main)
 - [Mistral-7B-Instruct-v0.1-GGUF](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF)
+- [Mixtral-8x7B-v0.1-GGUF](https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF)
 - [Baichuan2-7B-Chat-GGUF](https://huggingface.co/second-state/Baichuan2-7B-Chat-GGUF/tree/main)
 
 ## Requirements
 To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../../../README.md#system-support) for more information.
 
-**Important: Please make sure you have installed `transformers==4.33.0` to run the example.**
+**Important: Please make sure you have installed `transformers==4.36.0` to run the example.**
 
 ## Example: Load gguf model using `from_gguf()` API
 In the example [generate.py](./generate.py), we show a basic use case to load a GGUF LLaMA2 model into `bigdl-llm` using `from_gguf()` API, with BigDL-LLM optimizations.
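The generate.py script itself is not part of this diff; as a rough illustration of the `from_gguf()` flow the README describes, a Mixtral variant of that use case might look like the sketch below. The `AutoModelForCausalLM` entry point, its argument handling, and the GGUF file path are assumptions, not taken from this commit.

```python
# Minimal sketch, not from the commit: the file path, prompt, and the exact
# from_gguf() signature on AutoModelForCausalLM are assumptions.
from bigdl.llm.transformers import AutoModelForCausalLM

# from_gguf() is expected to return both the optimized model and a tokenizer
model, tokenizer = AutoModelForCausalLM.from_gguf(
    "mixtral-8x7b-v0.1.Q4_0.gguf")  # hypothetical local GGUF checkpoint

prompt = "What is AI?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```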
@@ -40,17 +40,19 @@ def load_gguf_model(fpath: str, dtype: torch.dtype = torch.float):
 
     with torch.no_grad():
         if model_family == "llama":
-            model_name = loader.config["general.name"].lower()
-            if "mistral" in model_name:
+            general_name = loader.config["general.name"].lower()
+            if "mixtral" in general_name:
+                # mixtral also follows the general llama architecture
+                from .models.mixtral import load_gguf_mixtral
+                model, tokenizer = load_gguf_mixtral(loader, dtype)
+            elif "mistral" in general_name:
                 from .models.mistral import load_gguf_mistral
                 model, tokenizer = load_gguf_mistral(loader, dtype)
             else:
                 from .models.llama import load_gguf_llama
                 model, tokenizer = load_gguf_llama(loader, dtype)
         elif model_family == "baichuan":
             from .models.baichuan import load_gguf_baichuan
             model, tokenizer = load_gguf_baichuan(loader, dtype)
         else:
             invalidInputError(False, f"Unsupported model family: {model_family}")
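A short, hedged sketch of how this dispatch would be exercised for the new model family; the import path of `load_gguf_model` and its `(model, tokenizer)` return value are assumed from the surrounding package layout rather than stated in the diff.

```python
# Hypothetical driver for the internal loader shown above; the import path
# (bigdl.llm.transformers.gguf.api) and the (model, tokenizer) return value
# are assumptions, not confirmed by this commit.
import torch
from bigdl.llm.transformers.gguf.api import load_gguf_model

# For a Mixtral checkpoint, loader.config["general.name"] contains "mixtral",
# so the new load_gguf_mixtral branch is taken.
model, tokenizer = load_gguf_model("mixtral-8x7b-v0.1.Q4_0.gguf", dtype=torch.half)
```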
							
								
								
									
python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py (new file, 107 lines)
@@ -0,0 +1,107 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from tempfile import NamedTemporaryFile
from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer

from ..gguf import GGUFFileLoader


def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
    # mixtral follows the general llama architecture,
    # e.g. it uses the llama tokenizer
    config = loader.config
    num_local_experts = config['llama.expert_count']
    num_experts_per_tok = config['llama.expert_used_count']
    n_head = config['llama.attention.head_count']
    n_head_kv = config['llama.attention.head_count_kv']
    hidden_size = config['llama.embedding_length']

    mixtral_config = MixtralConfig(
        vocab_size=len(config['tokenizer.ggml.tokens']),
        hidden_size=hidden_size,
        intermediate_size=config['llama.feed_forward_length'],
        num_hidden_layers=config['llama.block_count'],
        num_attention_heads=config['llama.attention.head_count'],
        num_key_value_heads=config['llama.attention.head_count_kv'],
        max_position_embeddings=config['llama.context_length'],
        rms_norm_eps=config['llama.attention.layer_norm_rms_epsilon'],
        pad_token_id=config['tokenizer.ggml.padding_token_id'],
        bos_token_id=config['tokenizer.ggml.bos_token_id'],
        eos_token_id=config['tokenizer.ggml.eos_token_id'],
        num_local_experts=num_local_experts,
        num_experts_per_tok=num_experts_per_tok,
    )

    ckpt = loader.tensors(dtype)
    from .llama import restore_llama_weight
    ckpt = restore_llama_weight(ckpt, n_head, n_head_kv)

    # map GGUF tensor names (blk.{i}.*) to HuggingFace Mixtral parameter names
    state_dict = {}
    state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
    state_dict['model.norm.weight'] = ckpt['output_norm.weight']
    state_dict['lm_head.weight'] = ckpt['output.weight']
    for i in range(config['llama.block_count']):
        state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = \
            ckpt[f'blk.{i}.attn_q.weight']
        state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = \
            ckpt[f'blk.{i}.attn_k.weight']
        state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = \
            ckpt[f'blk.{i}.attn_v.weight']
        state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
            ckpt[f'blk.{i}.attn_output.weight']
        state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
            ckpt[f'blk.{i}.attn_norm.weight']
        state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
            ckpt[f'blk.{i}.ffn_norm.weight']
        state_dict[f'model.layers.{i}.block_sparse_moe.gate.weight'] = \
            ckpt[f'blk.{i}.ffn_gate_inp.weight'].reshape(num_local_experts, hidden_size)
        for j in range(num_local_experts):
            state_dict[f'model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight'] = \
                ckpt[f'blk.{i}.ffn_gate.{j}.weight']
            state_dict[f'model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight'] = \
                ckpt[f'blk.{i}.ffn_down.{j}.weight']
            state_dict[f'model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight'] = \
                ckpt[f'blk.{i}.ffn_up.{j}.weight']

    with init_empty_weights():
        model = MixtralForCausalLM(mixtral_config)

    # materialize the weights onto the meta-initialized model on CPU with the requested dtype
    for name, weight in state_dict.items():
        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)

    model = model.cpu()

    # rebuild a sentencepiece model from the GGUF tokenizer pieces and wrap it in LlamaTokenizer
    from transformers.convert_slow_tokenizer import import_protobuf
    spm_pb2 = import_protobuf("Failed to import protobuf")

    tokenizer_pieces = loader.tokenizer_pieces()
    trainer_spec = spm_pb2.TrainerSpec(byte_fallback=True,
                                       model_type=spm_pb2.TrainerSpec.ModelType.BPE)
    proto = spm_pb2.ModelProto(pieces=tokenizer_pieces, trainer_spec=trainer_spec)
    proto = proto.SerializeToString()

    with NamedTemporaryFile(delete=False) as f:
        f.write(proto)
        f.close()
        tokenizer = LlamaTokenizer(f.name)
        os.remove(f.name)

    return model, tokenizer
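Since the renaming from GGUF tensor names (`blk.{i}.attn_q.weight`, `ffn_gate.{j}.weight`, ...) to HuggingFace Mixtral parameter names is the error-prone part of this loader, a small hedged check along the following lines can confirm that every remapped key exists on the instantiated model. This check is illustrative only and not part of the commit.

```python
# Illustrative sanity check, not in the commit: if placed inside
# load_gguf_mixtral after `model` and `state_dict` are built (and before the
# weights are assigned), it confirms every remapped GGUF tensor name matches
# a real MixtralForCausalLM parameter.
expected = dict(model.named_parameters())
unknown = [name for name in state_dict if name not in expected]
assert not unknown, f"tensor names not present in MixtralForCausalLM: {unknown}"
```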