From b909c5c9c2cf942af3cc354afbaa3df7b3b13cbf Mon Sep 17 00:00:00 2001
From: Shaojun Liu <61072813+liu-shaojun@users.noreply.github.com>
Date: Tue, 16 Jan 2024 18:54:39 +0800
Subject: [PATCH] GGUF load memory optimization (#9913)

* block-wise

* convert linear for module

* revert

* Fix PEP8 checks Error
---
 .../llm/src/bigdl/llm/transformers/convert.py | 139 ++++++++++++++++++
 .../llm/transformers/gguf/models/llama.py     | 104 +++++++------
 2 files changed, 197 insertions(+), 46 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 0ef053e4..ac05a28f 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -327,6 +327,145 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
     return model, has_been_replaced
 
 
+def replace_with_low_bit_linear_for_module(model, qtype, module_name=None,
+                                           modules_to_not_convert=None, current_key_name=None,
+                                           convert_shape_only=False, torch_dtype="auto"):
+    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
+        FP16Linear, BF16Linear
+    has_been_replaced = False
+
+    if "." in module_name:
+        splits = module_name.split(".")
+        parent_module = getattr(model, splits[0])
+
+        if "lm_head" not in module_name:
+            for split in splits[1:-2]:
+                new_module = getattr(parent_module, split)
+                parent_module = new_module
+            module = getattr(parent_module, splits[-2])
+            module_name = splits[-2]
+        else:
+            module = parent_module
+            parent_module = model
+            module_name = splits[0]
+
+    if current_key_name is None:
+        current_key_name = []
+
+    if modules_to_not_convert is None:
+        modules_to_not_convert = []
+
+    is_linear, linear_args = is_linear_module(module)
+    if is_linear and module_name not in modules_to_not_convert:
+        # Check if the current key is not in the `modules_to_not_convert`
+        if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
+            module.weight.data.device.type != 'meta' and not isinstance(module, LowBitLinear)):
+            in_features, out_features, mp_group = linear_args
+            with init_empty_weights():
+                new_linear = None
+                is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
+                is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
+                is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
+                if is_gptq or is_awq:
+                    has_bias = module.bias is not None and module.bias.abs().sum() != 0
+                    new_linear = LowBitLinear(
+                        in_features,
+                        out_features,
+                        qtype=qtype,
+                        bias=has_bias,
+                        mp_group=mp_group,
+                    )
+                    device = module.qweight.data.device
+                    invalidInputError(device.type != "meta",
+                                      "converting from meta device is not supported")
+                    # Copy the weights
+                    paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
+                                                               llm_awq=is_llm_awq),
+                                             requires_grad=False,
+                                             quantized=True,
+                                             _shape=(out_features, in_features),
+                                             convert_shape_only=convert_shape_only,
+                                             qtype=qtype).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
+                    if has_bias:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
+                    new_linear = LowBitLinear(
+                        in_features,
+                        out_features,
+                        qtype,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                    )
+
+                    device = module.weight.data.device
+                    # Copy the weights
+                    paramsLowBit = FP4Params(data=module.weight.data,
+                                             requires_grad=False,
+                                             quantized=False,
+                                             _shape=None,
+                                             convert_shape_only=convert_shape_only,
+                                             qtype=qtype).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype == ggml_tensor_qtype["fp16"]:
+                    module.to(torch.float16)
+                    new_linear = FP16Linear(
+                        in_features,
+                        out_features,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                    )
+                    device = module.weight.data.device
+                    from bigdl.llm.transformers.utils import get_ipex_version
+                    if get_ipex_version() < "2.1.10+xpu":
+                        new_linear._parameters['weight'] = nn.Parameter(module.weight)
+                    else:
+                        # only from 2.1, ipex provides matmul_bias_out
+                        # so we need to transpose weight
+                        new_weight = module.weight.transpose(0, 1).contiguous()
+                        new_linear._parameters['weight'] = nn.Parameter(new_weight)
+                        new_linear.weight_type = 2
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype == ggml_tensor_qtype["bf16"]:
+                    module.to(torch.bfloat16)
+                    new_linear = BF16Linear(
+                        in_features,
+                        out_features,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                    )
+                    device = module.weight.data.device
+                    # convert here
+                    new_linear._parameters['weight'] = nn.Parameter(module.weight)
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+
+                if new_linear is not None:
+                    if not module.training:
+                        new_linear.eval()
+                    parent_module._modules[module_name] = new_linear
+                    has_been_replaced = True
+                    # Force requires grad to False to avoid unexpected errors
+                    parent_module._modules[module_name].requires_grad_(False)
+
+                    module.weight = None
+
+    if has_been_replaced:
+        if not (getattr(model, "quantization_method", None) == "gptq"):
+            if torch_dtype == "auto":
+                convert_bigdl_other_module(model, torch.float32)
+            else:
+                convert_bigdl_other_module(model, torch_dtype)
+    return model
+
+
 def _optimize_pre(model):
     from transformers.modeling_utils import PreTrainedModel
     # All huggingface format models are inherited from `PreTrainedModel`
diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/llama.py b/python/llm/src/bigdl/llm/transformers/gguf/models/llama.py
index f35551f9..f40eeab3 100644
--- a/python/llm/src/bigdl/llm/transformers/gguf/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/gguf/models/llama.py
@@ -22,9 +22,12 @@ from tempfile import NamedTemporaryFile
 from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 from ..gguf import GGUFFileLoader
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
 
 
-def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
+def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
+                    low_bit='sym_int4'):
     config = loader.config
 
     llama_config = LlamaConfig(
@@ -44,42 +47,40 @@ def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
         pretraining_tp=1,
     )
 
-    ckpt = loader.tensors(dtype)
+    qtype = ggml_tensor_qtype[low_bit]
     n_head = config['llama.attention.head_count']
     n_head_kv = config['llama.attention.head_count_kv']
-    ckpt = restore_llama_weight(ckpt, n_head, n_head_kv)
-
-    state_dict = {}
-    state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
-    state_dict['model.norm.weight'] = ckpt['output_norm.weight']
-    state_dict['lm_head.weight'] = ckpt['output.weight']
-    for i in range(config['llama.block_count']):
-        state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_q.weight']
-        state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_k.weight']
-        state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_v.weight']
-        state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_output.weight']
-        state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_gate.weight']
-        state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_up.weight']
-        state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_down.weight']
-        state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
-            ckpt[f'blk.{i}.attn_norm.weight']
-        state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
-            ckpt[f'blk.{i}.ffn_norm.weight']
 
     with init_empty_weights():
         model = LlamaForCausalLM(llama_config)
 
-    for name, weight in state_dict.items():
-        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
-
-    model = model.cpu()
+    def process_llama(name, tensor):
+        nonlocal model
+        module_name = get_llama_module_name(name)
+        if 'q_proj' in module_name:
+            # gguf weight needs to reshape for q_proj
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            set_module_tensor_to_device(model, module_name, "cpu",
+                                        tensor.reshape(n_head, head // n_head // 2, 2, *hd_size)
+                                        .swapaxes(1, 2)
+                                        .reshape(tensor.shape),
+                                        dtype=dtype)
+        elif 'k_proj' in module_name:
+            # gguf weight needs to reshape for k_proj
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            set_module_tensor_to_device(model, module_name, "cpu",
+                                        tensor.reshape(n_head_kv,
+                                                       head // n_head_kv // 2,
+                                                       2,
+                                                       *hd_size)
+                                        .swapaxes(1, 2)
+                                        .reshape(tensor.shape),
+                                        dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, module_name, "cpu", tensor, dtype=dtype)
+        model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
+    tensor_loader = loader.tensor_loader
+    tensor_loader.load_while_process(process_llama)
 
     # see https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
     from transformers.convert_slow_tokenizer import import_protobuf
@@ -100,18 +101,29 @@ def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     return model, tokenizer
 
 
-def restore_llama_weight(ckpt: dict, n_head: int, n_head_kv: int):
-    # see https://github.com/ggerganov/llama.cpp/blob
-    # /3e73d31d9cc0232882ce61c64742aff3ecfec416/convert.py#L978
-
-    for name, weight in ckpt.items():
-        head, hd_size = weight.shape[0], weight.shape[1:]
-        if name.endswith("attn_q.weight"):
-            ckpt[name] = (weight.reshape(n_head, head // n_head // 2, 2, *hd_size)
-                          .swapaxes(1, 2)
-                          .reshape(weight.shape))
-        elif name.endswith("attn_k.weight"):
-            ckpt[name] = (weight.reshape(n_head_kv, head // n_head_kv // 2, 2, *hd_size)
-                          .swapaxes(1, 2)
-                          .reshape(weight.shape))
-    return ckpt
+def get_llama_module_name(name):
+    if name == 'token_embd.weight':
+        return 'model.embed_tokens.weight'
+    if name == 'output_norm.weight':
+        return 'model.norm.weight'
+    if name == 'output.weight':
+        return 'lm_head.weight'
+    layer_id = name.split('.')[1]
+    if 'attn_q' in name:
+        return f'model.layers.{layer_id}.self_attn.q_proj.weight'
+    if 'attn_k' in name:
+        return f'model.layers.{layer_id}.self_attn.k_proj.weight'
+    if 'attn_v' in name:
+        return f'model.layers.{layer_id}.self_attn.v_proj.weight'
+    if 'attn_output' in name:
+        return f'model.layers.{layer_id}.self_attn.o_proj.weight'
+    if 'ffn_gate' in name:
+        return f'model.layers.{layer_id}.mlp.gate_proj.weight'
+    if 'ffn_up' in name:
+        return f'model.layers.{layer_id}.mlp.up_proj.weight'
+    if 'ffn_down' in name:
+        return f'model.layers.{layer_id}.mlp.down_proj.weight'
+    if 'attn_norm' in name:
+        return f'model.layers.{layer_id}.input_layernorm.weight'
+    if 'ffn_norm' in name:
+        return f'model.layers.{layer_id}.post_attention_layernorm.weight'
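
Note on the q_proj/k_proj branches in process_llama: they apply, tensor by tensor, the same per-head permutation that the removed restore_llama_weight applied to the whole checkpoint at once (reversing the q/k permutation from llama.cpp's convert.py, see the link in the removed comment). The toy sketch below, with hypothetical head counts and sizes, only uses plain PyTorch and mirrors the reshape/swapaxes/reshape expression from the diff so the effect on a weight matrix is visible.

# Toy illustration of the per-head q/k permutation used in process_llama.
# The dimensions below are hypothetical; real llama q/k weights have shape
# (n_head * head_dim, hidden_size).
import torch

n_head = 2    # hypothetical number of attention heads
head_dim = 4  # hypothetical per-head dimension
hidden = 8    # hypothetical hidden size

weight = torch.arange(n_head * head_dim * hidden,
                      dtype=torch.float).reshape(n_head * head_dim, hidden)

head, hd_size = weight.shape[0], weight.shape[1:]
# Same expression as in the diff: within each head, rows are regrouped into
# (head_dim // 2, 2) pairs and the pair axis is swapped, reversing the q/k
# permutation applied by llama.cpp's convert.py, then the result is flattened
# back to the original shape.
restored = (weight.reshape(n_head, head // n_head // 2, 2, *hd_size)
            .swapaxes(1, 2)
            .reshape(weight.shape))

print(weight.shape, restored.shape)   # both torch.Size([8, 8])
print(torch.equal(weight, restored))  # False: row order within each head changes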
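
For orientation, a minimal usage sketch of the block-wise loading path added by this patch. The import paths are inferred from the diff (the relative import `from ..gguf import GGUFFileLoader` in llama.py resolves to bigdl.llm.transformers.gguf.gguf), the GGUF file name is hypothetical, and GGUFFileLoader's constructor arguments are not shown in this patch, so treat this as an assumption-laden sketch rather than the library's documented entry point.

# Hypothetical sketch: load a GGUF llama checkpoint block by block and convert
# each linear layer to a low-bit module on the fly, as implemented in this patch.
import torch

from bigdl.llm.transformers.gguf.gguf import GGUFFileLoader  # path inferred from the diff
from bigdl.llm.transformers.gguf.models.llama import load_gguf_llama
from bigdl.llm.transformers.low_bit_linear import LowBitLinear

# Assumption: GGUFFileLoader is constructed from a .gguf file path.
loader = GGUFFileLoader("llama-2-7b-chat.Q4_0.gguf")

# Each GGUF tensor is mapped to its HF module name, written into the
# meta-initialized LlamaForCausalLM, and the owning nn.Linear is immediately
# replaced by a low-bit layer, so the full-precision checkpoint is never
# resident in memory all at once.
model, tokenizer = load_gguf_llama(loader, dtype=torch.float, low_bit='sym_int4')

converted = sum(1 for _, m in model.named_modules() if isinstance(m, LowBitLinear))
print(f"low-bit linear layers after block-wise conversion: {converted}")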