GGUF load memory optimization (#9913)
* block-wise * convert linear for module * revert * Fix PEP8 checks Error
This commit is contained in:
parent
8643b62521
commit
b909c5c9c2
2 changed files with 197 additions and 46 deletions
|
|
@ -327,6 +327,145 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||||
return model, has_been_replaced
|
return model, has_been_replaced
|
||||||
|
|
||||||
|
|
||||||
|
def replace_with_low_bit_linear_for_module(model, qtype, module_name=None,
                                           modules_to_not_convert=None, current_key_name=None,
                                           convert_shape_only=False, torch_dtype="auto"):
    """
    Replace one linear layer of ``model`` with its low-bit counterpart, in place.

    Unlike a whole-model conversion, this converts only the module addressed by
    ``module_name``, so it can be called right after each tensor has been
    materialised.

    :param model: root module that owns the layer to convert.
    :param qtype: target quantization type (a value of ``ggml_tensor_qtype``).
    :param module_name: dotted name of a parameter, e.g.
        ``"model.layers.0.self_attn.q_proj.weight"``. The last component is
        dropped as the parameter name and the rest is resolved from ``model``.
        A name without a dot is treated as the name of a direct child module.
    :param modules_to_not_convert: names of modules that must be left untouched.
    :param current_key_name: list of key fragments; its ``"."``-joined form is
        matched against ``modules_to_not_convert``. Defaults to an empty list.
    :param convert_shape_only: forwarded to ``FP4Params``.
    :param torch_dtype: dtype handed to ``convert_bigdl_other_module`` after a
        replacement; ``"auto"`` selects ``torch.float32``.
    :return: ``model`` (the same object, possibly with one child replaced).
    """
    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
        FP16Linear, BF16Linear

    # The default `None` (or an empty name) cannot be resolved to a module;
    # fail with a clear message instead of a TypeError/NameError further down.
    invalidInputError(isinstance(module_name, str) and module_name != "",
                      f"module_name must be a non-empty string, but got {module_name!r}")

    # Resolve the owning module generically: drop the trailing parameter name,
    # walk down to the parent, and keep the leaf name for the replacement.
    path = module_name.split(".")
    if len(path) > 1:
        path = path[:-1]
    parent_module = model
    for attr in path[:-1]:
        parent_module = getattr(parent_module, attr)
    module_name = path[-1]
    module = getattr(parent_module, module_name)

    if current_key_name is None:
        current_key_name = []

    if modules_to_not_convert is None:
        modules_to_not_convert = []

    has_been_replaced = False
    is_linear, linear_args = is_linear_module(module)
    # Evaluation order matters: `module.weight` is only touched for linear modules.
    convertible = (
        is_linear
        and module_name not in modules_to_not_convert
        # Check if the current key is not in the `modules_to_not_convert`
        and not any(key in ".".join(current_key_name) for key in modules_to_not_convert)
        and module.weight.data.device.type != 'meta'
        and not isinstance(module, LowBitLinear)
    )

    if convertible:
        in_features, out_features, mp_group = linear_args
        with init_empty_weights():
            new_linear = None
            is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
            is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
            is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
            if is_gptq or is_awq:
                # An all-zero bias is treated as "no bias".
                has_bias = module.bias is not None and module.bias.abs().sum() != 0
                new_linear = LowBitLinear(
                    in_features,
                    out_features,
                    qtype=qtype,
                    bias=has_bias,
                    mp_group=mp_group,
                )
                device = module.qweight.data.device
                invalidInputError(device.type != "meta",
                                  "converting from meta device is not supported")
                # Copy the weights
                paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
                                                           llm_awq=is_llm_awq),
                                         requires_grad=False,
                                         quantized=True,
                                         _shape=(out_features, in_features),
                                         convert_shape_only=convert_shape_only,
                                         qtype=qtype).to(device)
                new_linear._parameters['weight'] = paramsLowBit
                if has_bias:
                    new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                        .to(device)
            elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
                new_linear = LowBitLinear(
                    in_features,
                    out_features,
                    qtype,
                    module.bias is not None,
                    mp_group=mp_group,
                )

                device = module.weight.data.device
                # Copy the weights
                paramsLowBit = FP4Params(data=module.weight.data,
                                         requires_grad=False,
                                         quantized=False,
                                         _shape=None,
                                         convert_shape_only=convert_shape_only,
                                         qtype=qtype).to(device)
                new_linear._parameters['weight'] = paramsLowBit
                if module.bias is not None:
                    new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                        .to(device)
            elif qtype == ggml_tensor_qtype["fp16"]:
                module.to(torch.float16)
                new_linear = FP16Linear(
                    in_features,
                    out_features,
                    module.bias is not None,
                    mp_group=mp_group,
                )
                device = module.weight.data.device
                from bigdl.llm.transformers.utils import get_ipex_version
                # NOTE(review): this is a plain string comparison of version
                # numbers - confirm it orders all supported ipex versions correctly.
                if get_ipex_version() < "2.1.10+xpu":
                    new_linear._parameters['weight'] = nn.Parameter(module.weight)
                else:
                    # only from 2.1, ipex provides matmul_bias_out
                    # so we need to transpose weight
                    new_weight = module.weight.transpose(0, 1).contiguous()
                    new_linear._parameters['weight'] = nn.Parameter(new_weight)
                    new_linear.weight_type = 2
                if module.bias is not None:
                    new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                        .to(device)
            elif qtype == ggml_tensor_qtype["bf16"]:
                module.to(torch.bfloat16)
                new_linear = BF16Linear(
                    in_features,
                    out_features,
                    module.bias is not None,
                    mp_group=mp_group,
                )
                device = module.weight.data.device
                # convert here
                new_linear._parameters['weight'] = nn.Parameter(module.weight)
                if module.bias is not None:
                    new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                        .to(device)

            if new_linear is not None:
                if not module.training:
                    new_linear.eval()
                parent_module._modules[module_name] = new_linear
                has_been_replaced = True
                # Force requires grad to False to avoid unexpected errors
                parent_module._modules[module_name].requires_grad_(False)

                # Drop the reference to the original weight held by the old module.
                module.weight = None

    if has_been_replaced:
        if not (getattr(model, "quantization_method", None) == "gptq"):
            if torch_dtype == "auto":
                convert_bigdl_other_module(model, torch.float32)
            else:
                convert_bigdl_other_module(model, torch_dtype)
    return model
|
||||||
|
|
||||||
|
|
||||||
def _optimize_pre(model):
|
def _optimize_pre(model):
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
# All huggingface format models are inherited from `PreTrainedModel`
|
# All huggingface format models are inherited from `PreTrainedModel`
|
||||||
|
|
|
||||||
|
|
@ -22,9 +22,12 @@ from tempfile import NamedTemporaryFile
|
||||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||||
|
|
||||||
from ..gguf import GGUFFileLoader
|
from ..gguf import GGUFFileLoader
|
||||||
|
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||||
|
from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
|
||||||
|
|
||||||
|
|
||||||
def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
|
def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
|
||||||
|
low_bit='sym_int4'):
|
||||||
config = loader.config
|
config = loader.config
|
||||||
|
|
||||||
llama_config = LlamaConfig(
|
llama_config = LlamaConfig(
|
||||||
|
|
@ -44,42 +47,40 @@ def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
|
||||||
pretraining_tp=1,
|
pretraining_tp=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
ckpt = loader.tensors(dtype)
|
qtype = ggml_tensor_qtype[low_bit]
|
||||||
n_head = config['llama.attention.head_count']
|
n_head = config['llama.attention.head_count']
|
||||||
n_head_kv = config['llama.attention.head_count_kv']
|
n_head_kv = config['llama.attention.head_count_kv']
|
||||||
ckpt = restore_llama_weight(ckpt, n_head, n_head_kv)
|
|
||||||
|
|
||||||
state_dict = {}
|
|
||||||
state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
|
|
||||||
state_dict['model.norm.weight'] = ckpt['output_norm.weight']
|
|
||||||
state_dict['lm_head.weight'] = ckpt['output.weight']
|
|
||||||
for i in range(config['llama.block_count']):
|
|
||||||
state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = \
|
|
||||||
ckpt[f'blk.{i}.attn_q.weight']
|
|
||||||
state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = \
|
|
||||||
ckpt[f'blk.{i}.attn_k.weight']
|
|
||||||
state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = \
|
|
||||||
ckpt[f'blk.{i}.attn_v.weight']
|
|
||||||
state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
|
|
||||||
ckpt[f'blk.{i}.attn_output.weight']
|
|
||||||
state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = \
|
|
||||||
ckpt[f'blk.{i}.ffn_gate.weight']
|
|
||||||
state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = \
|
|
||||||
ckpt[f'blk.{i}.ffn_up.weight']
|
|
||||||
state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = \
|
|
||||||
ckpt[f'blk.{i}.ffn_down.weight']
|
|
||||||
state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
|
|
||||||
ckpt[f'blk.{i}.attn_norm.weight']
|
|
||||||
state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
|
|
||||||
ckpt[f'blk.{i}.ffn_norm.weight']
|
|
||||||
|
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = LlamaForCausalLM(llama_config)
|
model = LlamaForCausalLM(llama_config)
|
||||||
|
|
||||||
for name, weight in state_dict.items():
|
def process_llama(name, tensor):
|
||||||
set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
|
nonlocal model
|
||||||
|
module_name = get_llama_module_name(name)
|
||||||
model = model.cpu()
|
if 'q_proj' in module_name:
|
||||||
|
# gguf weight needs to reshape for q_proj
|
||||||
|
head, hd_size = tensor.shape[0], tensor.shape[1:]
|
||||||
|
set_module_tensor_to_device(model, module_name, "cpu",
|
||||||
|
tensor.reshape(n_head, head // n_head // 2, 2, *hd_size)
|
||||||
|
.swapaxes(1, 2)
|
||||||
|
.reshape(tensor.shape),
|
||||||
|
dtype=dtype)
|
||||||
|
elif 'k_proj' in module_name:
|
||||||
|
# gguf weight needs to reshape for k_proj
|
||||||
|
head, hd_size = tensor.shape[0], tensor.shape[1:]
|
||||||
|
set_module_tensor_to_device(model, module_name, "cpu",
|
||||||
|
tensor.reshape(n_head_kv,
|
||||||
|
head // n_head_kv // 2,
|
||||||
|
2,
|
||||||
|
*hd_size)
|
||||||
|
.swapaxes(1, 2)
|
||||||
|
.reshape(tensor.shape),
|
||||||
|
dtype=dtype)
|
||||||
|
else:
|
||||||
|
set_module_tensor_to_device(model, module_name, "cpu", tensor, dtype=dtype)
|
||||||
|
model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
|
||||||
|
tensor_loader = loader.tensor_loader
|
||||||
|
tensor_loader.load_while_process(process_llama)
|
||||||
|
|
||||||
# see https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
|
# see https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
|
||||||
from transformers.convert_slow_tokenizer import import_protobuf
|
from transformers.convert_slow_tokenizer import import_protobuf
|
||||||
|
|
@ -100,18 +101,29 @@ def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
def restore_llama_weight(ckpt: dict, n_head: int, n_head_kv: int):
|
def get_llama_module_name(name):
    """
    Translate a GGUF llama tensor name into the matching Hugging Face
    ``LlamaForCausalLM`` parameter name.

    Three model-level tensors are mapped by exact name. Every other tensor is
    treated as per-layer: the second dot-separated component is used as the
    layer index and the tensor kind is found by substring match, e.g.
    ``blk.3.attn_q.weight`` -> ``model.layers.3.self_attn.q_proj.weight``.

    :param name: tensor name as stored in the GGUF file.
    :return: the corresponding parameter name (always ending in ``.weight``).
    :raises ValueError: if ``name`` is not one of the recognised tensors.
    """
    model_level = {
        'token_embd.weight': 'model.embed_tokens.weight',
        'output_norm.weight': 'model.norm.weight',
        'output.weight': 'lm_head.weight',
    }
    if name in model_level:
        return model_level[name]

    parts = name.split('.')
    if len(parts) < 2:
        # No layer index to extract; report the name instead of an IndexError.
        raise ValueError(f"Unrecognized GGUF llama tensor name: {name!r}")
    layer_id = parts[1]

    # Checked in order; markers are matched as substrings of the full name.
    per_layer = (
        ('attn_q', 'self_attn.q_proj'),
        ('attn_k', 'self_attn.k_proj'),
        ('attn_v', 'self_attn.v_proj'),
        ('attn_output', 'self_attn.o_proj'),
        ('ffn_gate', 'mlp.gate_proj'),
        ('ffn_up', 'mlp.up_proj'),
        ('ffn_down', 'mlp.down_proj'),
        ('attn_norm', 'input_layernorm'),
        ('ffn_norm', 'post_attention_layernorm'),
    )
    for marker, target in per_layer:
        if marker in name:
            return f'model.layers.{layer_id}.{target}.weight'

    # Fail here with the offending name rather than handing `None` to the caller.
    raise ValueError(f"Unrecognized GGUF llama tensor name: {name!r}")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue