GGUF load memory optimization (#9913)
* block-wise
* convert linear for module
* revert
* Fix PEP8 checks Error
This commit is contained in:
parent 8643b62521
commit b909c5c9c2

2 changed files with 197 additions and 46 deletions
@@ -327,6 +327,145 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
     return model, has_been_replaced
 
 
+def replace_with_low_bit_linear_for_module(model, qtype, module_name=None,
+                                           modules_to_not_convert=None, current_key_name=None,
+                                           convert_shape_only=False, torch_dtype="auto"):
+    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
+        FP16Linear, BF16Linear
+    has_been_replaced = False
+
+    if "." in module_name:
+        splits = module_name.split(".")
+    parent_module = getattr(model, splits[0])
+
+    if "lm_head" not in module_name:
+        for split in splits[1:-2]:
+            new_module = getattr(parent_module, split)
+            parent_module = new_module
+        module = getattr(parent_module, splits[-2])
+        module_name = splits[-2]
+    else:
+        module = parent_module
+        parent_module = model
+        module_name = splits[0]
+
+    if current_key_name is None:
+        current_key_name = []
+
+    if modules_to_not_convert is None:
+        modules_to_not_convert = []
+
+    is_linear, linear_args = is_linear_module(module)
+    if is_linear and module_name not in modules_to_not_convert:
+        # Check if the current key is not in the `modules_to_not_convert`
+        if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
+                module.weight.data.device.type != 'meta' and not isinstance(module, LowBitLinear)):
+            in_features, out_features, mp_group = linear_args
+            with init_empty_weights():
+                new_linear = None
+                is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
+                is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
+                is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
+                if is_gptq or is_awq:
+                    has_bias = module.bias is not None and module.bias.abs().sum() != 0
+                    new_linear = LowBitLinear(
+                        in_features,
+                        out_features,
+                        qtype=qtype,
+                        bias=has_bias,
+                        mp_group=mp_group,
+                    )
+                    device = module.qweight.data.device
+                    invalidInputError(device.type != "meta",
+                                      "converting from meta device is not supported")
+                    # Copy the weights
+                    paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
+                                                               llm_awq=is_llm_awq),
+                                             requires_grad=False,
+                                             quantized=True,
+                                             _shape=(out_features, in_features),
+                                             convert_shape_only=convert_shape_only,
+                                             qtype=qtype).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
+                    if has_bias:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
+                    new_linear = LowBitLinear(
+                        in_features,
+                        out_features,
+                        qtype,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                    )
+
+                    device = module.weight.data.device
+                    # Copy the weights
+                    paramsLowBit = FP4Params(data=module.weight.data,
+                                             requires_grad=False,
+                                             quantized=False,
+                                             _shape=None,
+                                             convert_shape_only=convert_shape_only,
+                                             qtype=qtype).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype == ggml_tensor_qtype["fp16"]:
+                    module.to(torch.float16)
+                    new_linear = FP16Linear(
+                        in_features,
+                        out_features,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                    )
+                    device = module.weight.data.device
+                    from bigdl.llm.transformers.utils import get_ipex_version
+                    if get_ipex_version() < "2.1.10+xpu":
+                        new_linear._parameters['weight'] = nn.Parameter(module.weight)
+                    else:
+                        # only from 2.1, ipex provides matmul_bias_out
+                        # so we need to transpose weight
+                        new_weight = module.weight.transpose(0, 1).contiguous()
+                        new_linear._parameters['weight'] = nn.Parameter(new_weight)
+                        new_linear.weight_type = 2
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype == ggml_tensor_qtype["bf16"]:
+                    module.to(torch.bfloat16)
+                    new_linear = BF16Linear(
+                        in_features,
+                        out_features,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                    )
+                    device = module.weight.data.device
+                    # convert here
+                    new_linear._parameters['weight'] = nn.Parameter(module.weight)
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+
+                if new_linear is not None:
+                    if not module.training:
+                        new_linear.eval()
+                    parent_module._modules[module_name] = new_linear
+                    has_been_replaced = True
+                    # Force requires grad to False to avoid unexpected errors
+                    parent_module._modules[module_name].requires_grad_(False)
+
+                    module.weight = None
+
+    if has_been_replaced:
+        if not (getattr(model, "quantization_method", None) == "gptq"):
+            if torch_dtype == "auto":
+                convert_bigdl_other_module(model, torch.float32)
+            else:
+                convert_bigdl_other_module(model, torch_dtype)
+    return model
+
+
 def _optimize_pre(model):
     from transformers.modeling_utils import PreTrainedModel
     # All huggingface format models are inherited from `PreTrainedModel`
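The new helper converts a single named linear module in place, right after that module's weights have been materialized, instead of sweeping the whole model as `_replace_with_low_bit_linear` does. A minimal stand-alone call is sketched below; the tiny `LlamaConfig` is an assumption used only to keep the snippet self-contained and is not part of this commit.

# Minimal stand-alone sketch (not part of the diff): convert one attention
# projection of a freshly instantiated Llama model to sym_int4. The tiny
# config below is assumed only so the example runs on its own.
from transformers import LlamaConfig, LlamaForCausalLM
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module

model = LlamaForCausalLM(LlamaConfig(hidden_size=64, intermediate_size=128,
                                     num_hidden_layers=2, num_attention_heads=4,
                                     vocab_size=1000))
# module_name is the full parameter name, exactly as process_llama passes it.
model = replace_with_low_bit_linear_for_module(
    model, qtype=ggml_tensor_qtype['sym_int4'],
    module_name='model.layers.0.self_attn.q_proj.weight')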
@@ -22,9 +22,12 @@ from tempfile import NamedTemporaryFile
 from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 from ..gguf import GGUFFileLoader
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
 
 
-def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
+def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
+                    low_bit='sym_int4'):
     config = loader.config
 
     llama_config = LlamaConfig(
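The loader's entry point now takes the target precision directly via `low_bit`. A hypothetical call site is shown below; the `GGUFFileLoader` import path, constructor argument, and file name are assumptions, since none of them appears in this diff.

# Hypothetical usage of the updated entry point. Only the keyword arguments
# of load_gguf_llama come from this diff; the loader construction is assumed.
import torch
from bigdl.llm.transformers.gguf.gguf import GGUFFileLoader  # import path assumed

loader = GGUFFileLoader('llama-2-7b.Q4_0.gguf')  # constructor argument assumed
model, tokenizer = load_gguf_llama(loader, dtype=torch.float16, low_bit='sym_int4')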
@@ -44,42 +47,40 @@ def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
         pretraining_tp=1,
     )
 
-    ckpt = loader.tensors(dtype)
+    qtype = ggml_tensor_qtype[low_bit]
     n_head = config['llama.attention.head_count']
     n_head_kv = config['llama.attention.head_count_kv']
-    ckpt = restore_llama_weight(ckpt, n_head, n_head_kv)
 
-    state_dict = {}
-    state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
-    state_dict['model.norm.weight'] = ckpt['output_norm.weight']
-    state_dict['lm_head.weight'] = ckpt['output.weight']
-    for i in range(config['llama.block_count']):
-        state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_q.weight']
-        state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_k.weight']
-        state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_v.weight']
-        state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_output.weight']
-        state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_gate.weight']
-        state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_up.weight']
-        state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_down.weight']
-        state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
-            ckpt[f'blk.{i}.attn_norm.weight']
-        state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
-            ckpt[f'blk.{i}.ffn_norm.weight']
-
     with init_empty_weights():
         model = LlamaForCausalLM(llama_config)
 
-    for name, weight in state_dict.items():
-        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
-
-    model = model.cpu()
+    def process_llama(name, tensor):
+        nonlocal model
+        module_name = get_llama_module_name(name)
+        if 'q_proj' in module_name:
+            # gguf weight needs to reshape for q_proj
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            set_module_tensor_to_device(model, module_name, "cpu",
+                                        tensor.reshape(n_head, head // n_head // 2, 2, *hd_size)
+                                              .swapaxes(1, 2)
+                                              .reshape(tensor.shape),
+                                        dtype=dtype)
+        elif 'k_proj' in module_name:
+            # gguf weight needs to reshape for k_proj
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            set_module_tensor_to_device(model, module_name, "cpu",
+                                        tensor.reshape(n_head_kv,
+                                                       head // n_head_kv // 2,
+                                                       2,
+                                                       *hd_size)
+                                              .swapaxes(1, 2)
+                                              .reshape(tensor.shape),
+                                        dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, module_name, "cpu", tensor, dtype=dtype)
+        model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
+    tensor_loader = loader.tensor_loader
+    tensor_loader.load_while_process(process_llama)
 
     # see https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
     from transformers.convert_slow_tokenizer import import_protobuf
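This hunk is where the memory saving happens: rather than building a complete full-precision `state_dict` and quantizing afterwards, each GGUF tensor is placed into the empty-weight model and its linear module is converted to low-bit immediately, so at most one full-precision tensor needs to be resident at a time. The `load_while_process` callback contract could be implemented along the lines of the sketch below (illustrative only; the real tensor loader's internals are not shown in this diff).

# Illustrative sketch of a block-wise tensor loader: read one tensor, hand it
# to the callback, release it, then move on. Names other than
# load_while_process are assumptions, not the library's API.
from typing import Callable, Dict
import torch

class StreamingTensorLoader:
    def __init__(self, readers: Dict[str, Callable[[], torch.Tensor]]):
        # readers maps a GGUF tensor name (e.g. 'blk.0.attn_q.weight') to a
        # function that reads just that tensor from disk.
        self.readers = readers

    def load_while_process(self, process: Callable[[str, torch.Tensor], None]) -> None:
        for name, read in self.readers.items():
            tensor = read()        # only this tensor is held in full precision
            process(name, tensor)  # caller places it and quantizes the module
            del tensor             # drop the fp copy before reading the next one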
@@ -100,18 +101,29 @@ def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     return model, tokenizer
 
 
-def restore_llama_weight(ckpt: dict, n_head: int, n_head_kv: int):
-    # see https://github.com/ggerganov/llama.cpp/blob
-    # /3e73d31d9cc0232882ce61c64742aff3ecfec416/convert.py#L978
-
-    for name, weight in ckpt.items():
-        head, hd_size = weight.shape[0], weight.shape[1:]
-        if name.endswith("attn_q.weight"):
-            ckpt[name] = (weight.reshape(n_head, head // n_head // 2, 2, *hd_size)
-                                .swapaxes(1, 2)
-                                .reshape(weight.shape))
-        elif name.endswith("attn_k.weight"):
-            ckpt[name] = (weight.reshape(n_head_kv, head // n_head_kv // 2, 2, *hd_size)
-                                .swapaxes(1, 2)
-                                .reshape(weight.shape))
-    return ckpt
+def get_llama_module_name(name):
+    if name == 'token_embd.weight':
+        return 'model.embed_tokens.weight'
+    if name == 'output_norm.weight':
+        return 'model.norm.weight'
+    if name == 'output.weight':
+        return 'lm_head.weight'
+    layer_id = name.split('.')[1]
+    if 'attn_q' in name:
+        return f'model.layers.{layer_id}.self_attn.q_proj.weight'
+    if 'attn_k' in name:
+        return f'model.layers.{layer_id}.self_attn.k_proj.weight'
+    if 'attn_v' in name:
+        return f'model.layers.{layer_id}.self_attn.v_proj.weight'
+    if 'attn_output' in name:
+        return f'model.layers.{layer_id}.self_attn.o_proj.weight'
+    if 'ffn_gate' in name:
+        return f'model.layers.{layer_id}.mlp.gate_proj.weight'
+    if 'ffn_up' in name:
+        return f'model.layers.{layer_id}.mlp.up_proj.weight'
+    if 'ffn_down' in name:
+        return f'model.layers.{layer_id}.mlp.down_proj.weight'
+    if 'attn_norm' in name:
+        return f'model.layers.{layer_id}.input_layernorm.weight'
+    if 'ffn_norm' in name:
+        return f'model.layers.{layer_id}.post_attention_layernorm.weight'
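The q/k reshuffle that `restore_llama_weight` used to apply across the whole checkpoint dict now happens per tensor inside `process_llama`, and `get_llama_module_name` supplies the GGUF-to-Hugging-Face name mapping. A few mappings implied by the branches above:

# Name mappings implied by get_llama_module_name above.
assert get_llama_module_name('token_embd.weight') == 'model.embed_tokens.weight'
assert get_llama_module_name('blk.0.attn_q.weight') == \
    'model.layers.0.self_attn.q_proj.weight'
assert get_llama_module_name('blk.31.ffn_down.weight') == \
    'model.layers.31.mlp.down_proj.weight'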