LLM: support iq2 for mixtral (#10191)
* support name mapping for mixtral
* support mixtral mixed quantization
* fix style
* fix
parent 079f2011ea
commit f7c96b19ef
2 changed files with 66 additions and 22 deletions
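For context, a minimal usage sketch of what this commit enables: loading Mixtral with 2-bit (iq2_xxs) weights plus an importance matrix. Only the "iq2_xxs" qtype is taken from the code below; the `imatrix` keyword name is an assumption for illustration and may differ from the released API.

# Hypothetical usage sketch, not part of this commit. `imatrix=` is an assumed
# keyword name; only the "iq2_xxs" qtype comes from the diff below.
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    load_in_low_bit="iq2_xxs",       # 2-bit quantization that needs an imatrix
    imatrix="mixtral-8x7b.imatrix",  # assumed kwarg: llama.cpp importance matrix file
    trust_remote_code=True,
)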
@@ -191,7 +191,8 @@ def convert_gptq(module, awq=False, llm_awq=False):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
-                                 imatrix_data=None, embedding_qtype=None):
+                                 imatrix_data=None, embedding_qtype=None,
+                                 model_type=None):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@@ -251,7 +252,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 )
                 cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
                                                                    full_module_name,
-                                                                   imatrix_data)
+                                                                   imatrix_data,
+                                                                   model_type)
                 device = module.weight.data.device
                 # Copy the weights
                 paramsLowBit = FP4Params(data=module.weight.data,
@@ -361,7 +363,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 cpu_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
-                embedding_qtype=embedding_qtype
+                embedding_qtype=embedding_qtype,
+                model_type=model_type
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -558,11 +561,17 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     if optimize_model:
         model = _optimize_pre(model)

+    # mixed quantization needs model_type to choose custom quantization strategy
+    if hasattr(model, "config"):
+        model_type = getattr(model.config, "model_type", None)
+    else:
+        model_type = None
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
         None, convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
-        embedding_qtype=embedding_qtype
+        embedding_qtype=embedding_qtype,
+        model_type=model_type
     )
     if not has_been_replaced:
         warnings.warn(
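The first file's changes only thread a new model_type argument from ggml_convert_low_bit down to get_cur_qtype_and_imatrix. A self-contained sketch of the detection logic added above, run against a stand-in object instead of a real transformers model:

# Sketch of the model_type detection added in ggml_convert_low_bit; _Cfg and
# _Model are stand-ins for a transformers model carrying a config.
class _Cfg:
    model_type = "mixtral"

class _Model:
    config = _Cfg()

def detect_model_type(model):
    # mixed quantization needs model_type to choose custom quantization strategy
    if hasattr(model, "config"):
        return getattr(model.config, "model_type", None)
    return None

print(detect_model_type(_Model()))  # mixtral
print(detect_model_type(object()))  # None: no config, fall back to the generic strategy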
@@ -197,16 +197,24 @@ def load_imatrix_data(imatrix_file):
         cur_len = imatrix.read(4)
         cur_len = int.from_bytes(cur_len, 'little')
         cur_name = str(imatrix.read(cur_len), encoding='utf-8')
-        # original cur_name looks like blk.14.attn_output.weight for llama
-        # TODO: how to better aligned and generalize
+        # cur_name looks like blk.14.attn_output.weight for llama / mistral,
+        # cur_name looks like blk.0.ffn_down.3.weight for mixtral and
+        # blk.17.ffn_gate_inp.weight for mixtral
         name_list = cur_name.split('.')
         layer = name_list[1]
         module_name = name_list[2]
-        if 'ffn' in module_name:
+        exp_id = None
+        if 'ffn' in module_name and len(name_list) == 4:
             module_name = module_name[4:]  # from ffn_gate to gate
+        elif 'ffn' in module_name and len(name_list) == 5:
+            # mixtral's mlp layer
+            module_name = module_name[4:]
+            exp_id = name_list[3]
         elif 'attn' in module_name:
             module_name = module_name[5]  # from attn_k to k, attn_output to o
         module_name = layer + '_' + module_name
+        if exp_id is not None:
+            module_name += '_' + exp_id
         ncall = imatrix.read(4)
         ncall = int.from_bytes(ncall, 'little')
         nval = imatrix.read(4)
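The parsing above turns llama.cpp imatrix entry names into per-layer lookup keys. A standalone sketch of just the name mapping (the binary file reading is omitted; imatrix_key is a local name used for illustration), fed with the example names from the comments:

# Standalone sketch of the imatrix name mapping above (file reading omitted).
def imatrix_key(cur_name):
    name_list = cur_name.split('.')
    layer = name_list[1]
    module_name = name_list[2]
    exp_id = None
    if 'ffn' in module_name and len(name_list) == 4:
        module_name = module_name[4:]   # ffn_gate -> gate, ffn_gate_inp -> gate_inp
    elif 'ffn' in module_name and len(name_list) == 5:
        module_name = module_name[4:]   # mixtral expert mlp: ffn_down -> down
        exp_id = name_list[3]           # expert index
    elif 'attn' in module_name:
        module_name = module_name[5]    # attn_k -> k, attn_output -> o
    module_name = layer + '_' + module_name
    if exp_id is not None:
        module_name += '_' + exp_id
    return module_name

print(imatrix_key('blk.14.attn_output.weight'))   # 14_o
print(imatrix_key('blk.0.ffn_down.3.weight'))     # 0_down_3
print(imatrix_key('blk.17.ffn_gate_inp.weight'))  # 17_gate_inp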
@@ -225,32 +233,59 @@ def load_imatrix_data(imatrix_file):


 def module_name_process(full_module_name):
-    # full name maybe model.layers.31.self_attn.o_proj
-    # TODO: how to better aligned and generalize
-    module_name = full_module_name.split('.')
-    if len(module_name) == 5:
-        layer = module_name[2]
-        cur_module = module_name[-1][:-5]
+    # full name maybe model.layers.31.self_attn.o_proj for llama/mistral
+    # full name maybe model.layers.0.block_sparse_moe.gate or
+    # model.layers.0.block_sparse_moe.experts.0.w1 for mixtral
+    module_name_list = full_module_name.split('.')
+    if len(module_name_list) >= 5:
+        super_module_name = module_name_list[3]
+    else:
+        super_module_name = None
+    exp_id = None
+    if super_module_name == 'block_sparse_moe':
+        # handle mixtral moe here
+        moe_mapping = {"w1": "gate", "w2": "down", "w3": "up"}
+        layer = module_name_list[2]
+        if len(module_name_list) == 5 and module_name_list[-1] == 'gate':
+            cur_module = 'gate_inp'  # mapping with imatrix
+        elif len(module_name_list) == 7:
+            exp_id = module_name_list[-2]
+            cur_module = module_name_list[-1]
+            cur_module = moe_mapping[cur_module]
         new_module_name = '_'.join([layer, cur_module])
-    elif len(module_name) == 1:
-        new_module_name = module_name[0]
-        layer = None
-        cur_module = None
+        if exp_id is not None:
+            new_module_name += '_' + exp_id
+    else:
+        if len(module_name_list) == 5:
+            layer = module_name_list[2]
+            cur_module = module_name_list[-1][:-5]
+            new_module_name = '_'.join([layer, cur_module])
+        elif len(module_name_list) == 1:
+            new_module_name = module_name_list[0]
+            layer = None
+            cur_module = None
     return new_module_name, layer, cur_module


-def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
+def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_type=None):
     cur_qtype = qtype
     if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
         # For quantization which needs importance matrix
         new_module_name, layer, cur_module = module_name_process(full_module_name)
         # custom mixed quantization strategy
-        if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
-            cur_qtype = ggml_tensor_qtype['q2_k']
+        if model_type == "mixtral":
+            if cur_module == 'v':
+                # llama.cpp use q4_K here
+                cur_qtype = ggml_tensor_qtype['sym_int4']
+            elif cur_module == 'down' and int(layer) in [0, 1, 2, 3]:
+                cur_qtype = ggml_tensor_qtype['q2_k']
+        else:
+            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
+                cur_qtype = ggml_tensor_qtype['q2_k']
         if imatrix_data is not None and new_module_name in imatrix_data:
             cur_imatrix = imatrix_data[new_module_name]
         else:
-            # if no imatrix is available, use fp8 for lm_head
+            # if no imatrix is available, use sym_int8 for lm_head
             cur_imatrix = None
             if new_module_name == 'lm_head':
                 cur_qtype = ggml_tensor_qtype['sym_int8']
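On the HF side, module_name_process now produces the same keys for Mixtral's block_sparse_moe modules, so the lookup into imatrix_data lines up with the keys built in load_imatrix_data. A quick check, assuming the function is importable from bigdl.llm.transformers.utils (the file path is not shown in this diff):

# Assumes module_name_process lives in bigdl.llm.transformers.utils; the diff
# does not name the changed files, so the import path is a guess.
from bigdl.llm.transformers.utils import module_name_process

for full_name in ['model.layers.31.self_attn.o_proj',               # llama/mistral attention
                  'model.layers.0.block_sparse_moe.gate',           # mixtral router gate
                  'model.layers.0.block_sparse_moe.experts.3.w2']:  # mixtral expert down proj
    new_name, layer, cur_module = module_name_process(full_name)
    print(full_name, '->', new_name)
# Expected keys: 31_o, 0_gate_inp, 0_down_3 -- matching the imatrix keys shown earlier.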
@@ -263,7 +298,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
         if imatrix_data is not None and new_module_name in imatrix_data:
             cur_imatrix = imatrix_data[new_module_name]
         else:
-            # if no imatrix is available, use fp8 for lm_head
+            # if no imatrix is available, use sym_int8 for lm_head
             cur_imatrix = None
             if new_module_name == 'lm_head':
                 cur_qtype = ggml_tensor_qtype['sym_int8']
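To summarize the mixed strategy for iq2_xxs/iq2_xs: for Mixtral, attention v projections fall back to sym_int4 (mirroring llama.cpp's q4_K choice) and the w2/down projections of layers 0-3 fall back to q2_k; other models keep the existing rule (v, plus down in layers 0, 1, 10 and 11, go to q2_k); lm_head falls back to sym_int8 when no imatrix entry is available. A self-contained restatement of the Mixtral branch (names are local to this sketch; the real logic lives in get_cur_qtype_and_imatrix):

# Local restatement of the mixtral-specific fallbacks above; returns the
# override qtype name, or None when the requested iq2 qtype is kept.
def mixtral_iq2_fallback(cur_module, layer):
    if cur_module == 'v':
        return 'sym_int4'                     # llama.cpp uses q4_K for attn_v
    if cur_module == 'down' and int(layer) in [0, 1, 2, 3]:
        return 'q2_k'                         # early down/w2 projections stay at q2_k
    return None

assert mixtral_iq2_fallback('v', '7') == 'sym_int4'
assert mixtral_iq2_fallback('down', '2') == 'q2_k'
assert mixtral_iq2_fallback('q', '20') is None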