gguf memory optimization for mixtral (#9939)
Commit 967714bac8 (parent 610b5226be)
3 changed files with 11 additions and 3 deletions
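In short: the Mixtral GGUF loader gains a low_bit argument (default 'sym_int4') and now calls replace_with_low_bit_linear_for_module on each module right after its tensor is placed on the CPU, as the Mistral loader does, so weights are quantized module by module during loading instead of all at once after the full-precision model is assembled. The extra conditions added to _replace_with_low_bit_linear let that per-module call skip layers whose weights are still on the meta device (not yet materialized) and layers that are already LowBitLinear instances.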
				
			
In _replace_with_low_bit_linear (bigdl.llm.transformers.convert):

@@ -202,7 +202,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         is_linear, linear_args = is_linear_module(module)
         if is_linear and name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+            if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
+                    module.weight.data.device.type != 'meta' and
+                    not isinstance(module, LowBitLinear)):
                 in_features, out_features, mp_group = linear_args
                 with init_empty_weights():
                     new_linear = None
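For context (not part of the commit), the two added conditions are what allow _replace_with_low_bit_linear to be invoked repeatedly while the model is only partially materialized. Below is a minimal standalone sketch of the same guard, with a placeholder class standing in for bigdl's LowBitLinear:

    import torch.nn as nn

    class LowBitLinear(nn.Linear):
        """Placeholder standing in for bigdl's LowBitLinear."""

    def should_convert(module, current_key_name, modules_to_not_convert):
        # Name-based exclusion, unchanged from before the commit.
        if any(key in ".".join(current_key_name) for key in modules_to_not_convert):
            return False
        # New check: skip weights that are still 'meta' placeholders; when
        # converting module by module during loading, most layers have not
        # been materialized yet.
        if module.weight.data.device.type == 'meta':
            return False
        # New check: skip modules already converted on an earlier call.
        if isinstance(module, LowBitLinear):
            return False
        return True

    print(should_convert(nn.Linear(4, 4), ["model", "layers", "0"], ["lm_head"]))  # True

A plain nn.Linear on the CPU passes the guard; a layer still on the meta device, or one that has already been swapped out, does not.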
In load_gguf_mistral (GGUF Mistral loader):

@@ -25,6 +25,7 @@ from ..gguf import GGUFFileLoader
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
 
 
 def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
                       low_bit='sym_int4'):
     config = loader.config
@@ -78,7 +79,7 @@ def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
         else:
             set_module_tensor_to_device(model, module_name, "cpu", tensor, dtype=dtype)
+        model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
 
 
     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mistral)
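Judging from the indentation, this call sits inside the per-tensor callback (process_mistral) driven by tensor_loader.load_while_process, so each weight is converted to its low-bit form as soon as it lands on the CPU rather than after the whole checkpoint has been loaded.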
In load_gguf_mixtral (GGUF Mixtral loader):

@@ -22,9 +22,12 @@ from tempfile import NamedTemporaryFile
 from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer
 
 from ..gguf import GGUFFileLoader
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
 
 
-def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
+def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
+                      low_bit='sym_int4'):
     # mixtral enjoys a general architecture of llma
     # e.g. it applies llama tokenizer
     config = loader.config
@@ -33,6 +36,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     n_head = config['llama.attention.head_count']
     n_head_kv = config['llama.attention.head_count_kv']
     hidden_size = config['llama.embedding_length']
+    qtype = ggml_tensor_qtype[low_bit]
 
     mixtral_config = MixtralConfig(
         vocab_size=len(config['tokenizer.ggml.tokens']),
@@ -81,6 +85,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
                                     "cpu",
                                     tensor,
                                     dtype=dtype)
+        model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
 
     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mixtral)
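Taken together, the loaders follow a load-while-convert pattern. The sketch below is illustrative only; the checkpoint dict, the process callback, and the float16 cast are hypothetical stand-ins for the GGUF tensor stream, process_mixtral, and the real sym_int4 quantization done by replace_with_low_bit_linear_for_module. The point it demonstrates: each tensor is materialized on the CPU and immediately shrunk, so the full-precision weights of the whole model never coexist in memory.

    import torch
    import torch.nn as nn

    # Toy two-layer "model" whose weights start on the meta device, i.e. nothing
    # is materialized yet (analogous to building under init_empty_weights()).
    with torch.device("meta"):
        model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

    def process(param_name, tensor):
        # 1) Materialize just this one tensor on the CPU ...
        module_path, _, attr = param_name.rpartition(".")
        module = model.get_submodule(module_path)
        setattr(module, attr, nn.Parameter(tensor, requires_grad=False))
        # 2) ... then immediately replace it with a smaller copy (float16 stands
        #    in for sym_int4 here), so the full-precision copy is short-lived.
        low_bit_param = getattr(module, attr).data.half()
        setattr(module, attr, nn.Parameter(low_bit_param, requires_grad=False))

    # Stand-in for tensor_loader.load_while_process(process_mixtral): tensors
    # arrive one at a time from the GGUF file.
    checkpoint = {"0.weight": torch.randn(8, 8), "0.bias": torch.randn(8),
                  "1.weight": torch.randn(8, 8), "1.bias": torch.randn(8)}
    for name, tensor in checkpoint.items():
        process(name, tensor)

    print(model[0].weight.dtype)  # torch.float16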