LLM: support iq2 for mixtral (#10191)

* support name mapping for mixtral

* support mixtral mixed quantization

* fix style

* fix
This commit is contained in:
Ruonan Wang 2024-02-21 16:00:29 +08:00 committed by GitHub
parent 079f2011ea
commit f7c96b19ef
2 changed files with 66 additions and 22 deletions

View file

@ -191,7 +191,8 @@ def convert_gptq(module, awq=False, llm_awq=False):
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
current_key_name=None, convert_shape_only=False,
cpu_embedding=False, prefix_name='',
imatrix_data=None, embedding_qtype=None):
imatrix_data=None, embedding_qtype=None,
model_type=None):
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
FP16Linear, BF16Linear
from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@ -251,7 +252,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
)
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
full_module_name,
imatrix_data)
imatrix_data,
model_type)
device = module.weight.data.device
# Copy the weights
paramsLowBit = FP4Params(data=module.weight.data,
@ -361,7 +363,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
cpu_embedding,
prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype
embedding_qtype=embedding_qtype,
model_type=model_type
)
has_been_replaced = _flag or has_been_replaced
return model, has_been_replaced
@ -558,11 +561,17 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
if optimize_model:
model = _optimize_pre(model)
# mixed quantization needs model_type to choose custom quantization strategy
if hasattr(model, "config"):
model_type = getattr(model.config, "model_type", None)
else:
model_type = None
model, has_been_replaced = _replace_with_low_bit_linear(
model, qtype, modules_to_not_convert,
None, convert_shape_only, cpu_embedding,
imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype
embedding_qtype=embedding_qtype,
model_type=model_type
)
if not has_been_replaced:
warnings.warn(

View file

@ -197,16 +197,24 @@ def load_imatrix_data(imatrix_file):
cur_len = imatrix.read(4)
cur_len = int.from_bytes(cur_len, 'little')
cur_name = str(imatrix.read(cur_len), encoding='utf-8')
# original cur_name looks like blk.14.attn_output.weight for llama
# TODO: how to better aligned and generalize
# cur_name looks like blk.14.attn_output.weight for llama / mistral,
# cur_name looks like blk.0.ffn_down.3.weight for mixtral and
# blk.17.ffn_gate_inp.weight for mixtral
name_list = cur_name.split('.')
layer = name_list[1]
module_name = name_list[2]
if 'ffn' in module_name:
exp_id = None
if 'ffn' in module_name and len(name_list) == 4:
module_name = module_name[4:] # from ffn_gate to gate
elif 'ffn' in module_name and len(name_list) == 5:
# mixtral's mlp layer
module_name = module_name[4:]
exp_id = name_list[3]
elif 'attn' in module_name:
module_name = module_name[5] # from attn_k to k, attn_output to o
module_name = layer + '_' + module_name
if exp_id is not None:
module_name += '_' + exp_id
ncall = imatrix.read(4)
ncall = int.from_bytes(ncall, 'little')
nval = imatrix.read(4)
@ -225,32 +233,59 @@ def load_imatrix_data(imatrix_file):
def module_name_process(full_module_name):
# full name maybe model.layers.31.self_attn.o_proj
# TODO: how to better aligned and generalize
module_name = full_module_name.split('.')
if len(module_name) == 5:
layer = module_name[2]
cur_module = module_name[-1][:-5]
# full name maybe model.layers.31.self_attn.o_proj for llama/mistral
# full name maybe model.layers.0.block_sparse_moe.gate or
# model.layers.0.block_sparse_moe.experts.0.w1 for mixtral
module_name_list = full_module_name.split('.')
if len(module_name_list) >= 5:
super_module_name = module_name_list[3]
else:
super_module_name = None
exp_id = None
if super_module_name == 'block_sparse_moe':
# handle mixtral moe here
moe_mapping = {"w1": "gate", "w2": "down", "w3": "up"}
layer = module_name_list[2]
if len(module_name_list) == 5 and module_name_list[-1] == 'gate':
cur_module = 'gate_inp' # mapping with imatrix
elif len(module_name_list) == 7:
exp_id = module_name_list[-2]
cur_module = module_name_list[-1]
cur_module = moe_mapping[cur_module]
new_module_name = '_'.join([layer, cur_module])
elif len(module_name) == 1:
new_module_name = module_name[0]
layer = None
cur_module = None
if exp_id is not None:
new_module_name += '_' + exp_id
else:
if len(module_name_list) == 5:
layer = module_name_list[2]
cur_module = module_name_list[-1][:-5]
new_module_name = '_'.join([layer, cur_module])
elif len(module_name_list) == 1:
new_module_name = module_name_list[0]
layer = None
cur_module = None
return new_module_name, layer, cur_module
def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_type=None):
cur_qtype = qtype
if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
# For quantization which needs importance matrix
new_module_name, layer, cur_module = module_name_process(full_module_name)
# custom mixed quantization strategy
if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
cur_qtype = ggml_tensor_qtype['q2_k']
if model_type == "mixtral":
if cur_module == 'v':
# llama.cpp use q4_K here
cur_qtype = ggml_tensor_qtype['sym_int4']
elif cur_module == 'down' and int(layer) in [0, 1, 2, 3]:
cur_qtype = ggml_tensor_qtype['q2_k']
else:
if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
cur_qtype = ggml_tensor_qtype['q2_k']
if imatrix_data is not None and new_module_name in imatrix_data:
cur_imatrix = imatrix_data[new_module_name]
else:
# if no imatrix is available, use fp8 for lm_head
# if no imatrix is available, use sym_int8 for lm_head
cur_imatrix = None
if new_module_name == 'lm_head':
cur_qtype = ggml_tensor_qtype['sym_int8']
@ -263,7 +298,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
if imatrix_data is not None and new_module_name in imatrix_data:
cur_imatrix = imatrix_data[new_module_name]
else:
# if no imatrix is available, use fp8 for lm_head
# if no imatrix is available, use sym_int8 for lm_head
cur_imatrix = None
if new_module_name == 'lm_head':
cur_qtype = ggml_tensor_qtype['sym_int8']