diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 3eca37de..a5fd8813 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -191,7 +191,8 @@ def convert_gptq(module, awq=False, llm_awq=False):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
-                                 imatrix_data=None, embedding_qtype=None):
+                                 imatrix_data=None, embedding_qtype=None,
+                                 model_type=None):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@@ -251,7 +252,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         )
                         cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
                                                                            full_module_name,
-                                                                           imatrix_data)
+                                                                           imatrix_data,
+                                                                           model_type)
                         device = module.weight.data.device
                         # Copy the weights
                         paramsLowBit = FP4Params(data=module.weight.data,
@@ -361,7 +363,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 cpu_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
-                embedding_qtype=embedding_qtype
+                embedding_qtype=embedding_qtype,
+                model_type=model_type
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -558,11 +561,17 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     if optimize_model:
         model = _optimize_pre(model)

+    # mixed quantization needs model_type to choose custom quantization strategy
+    if hasattr(model, "config"):
+        model_type = getattr(model.config, "model_type", None)
+    else:
+        model_type = None
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
         None, convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
-        embedding_qtype=embedding_qtype
+        embedding_qtype=embedding_qtype,
+        model_type=model_type
     )
     if not has_been_replaced:
         warnings.warn(
diff --git a/python/llm/src/bigdl/llm/transformers/utils.py b/python/llm/src/bigdl/llm/transformers/utils.py
index 7f85f03b..467bcaf6 100644
--- a/python/llm/src/bigdl/llm/transformers/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/utils.py
@@ -197,16 +197,24 @@ def load_imatrix_data(imatrix_file):
         cur_len = imatrix.read(4)
         cur_len = int.from_bytes(cur_len, 'little')
         cur_name = str(imatrix.read(cur_len), encoding='utf-8')
-        # original cur_name looks like blk.14.attn_output.weight for llama
-        # TODO: how to better aligned and generalize
+        # cur_name looks like blk.14.attn_output.weight for llama / mistral,
+        # cur_name looks like blk.0.ffn_down.3.weight for mixtral and
+        # blk.17.ffn_gate_inp.weight for mixtral
         name_list = cur_name.split('.')
         layer = name_list[1]
         module_name = name_list[2]
-        if 'ffn' in module_name:
+        exp_id = None
+        if 'ffn' in module_name and len(name_list) == 4:
             module_name = module_name[4:]  # from ffn_gate to gate
+        elif 'ffn' in module_name and len(name_list) == 5:
+            # mixtral's mlp layer
+            module_name = module_name[4:]
+            exp_id = name_list[3]
         elif 'attn' in module_name:
             module_name = module_name[5]  # from attn_k to k, attn_output to o
         module_name = layer + '_' + module_name
+        if exp_id is not None:
+            module_name += '_' + exp_id
         ncall = imatrix.read(4)
         ncall = int.from_bytes(ncall, 'little')
         nval = imatrix.read(4)
@@ -225,32 +233,59 @@ def load_imatrix_data(imatrix_file):


 def module_name_process(full_module_name):
-    # full name maybe model.layers.31.self_attn.o_proj
-    # TODO: how to better aligned and generalize
-    module_name = full_module_name.split('.')
-    if len(module_name) == 5:
-        layer = module_name[2]
-        cur_module = module_name[-1][:-5]
+    # full name maybe model.layers.31.self_attn.o_proj for llama/mistral
+    # full name maybe model.layers.0.block_sparse_moe.gate or
+    # model.layers.0.block_sparse_moe.experts.0.w1 for mixtral
+    module_name_list = full_module_name.split('.')
+    if len(module_name_list) >= 5:
+        super_module_name = module_name_list[3]
+    else:
+        super_module_name = None
+    exp_id = None
+    if super_module_name == 'block_sparse_moe':
+        # handle mixtral moe here
+        moe_mapping = {"w1": "gate", "w2": "down", "w3": "up"}
+        layer = module_name_list[2]
+        if len(module_name_list) == 5 and module_name_list[-1] == 'gate':
+            cur_module = 'gate_inp'  # mapping with imatrix
+        elif len(module_name_list) == 7:
+            exp_id = module_name_list[-2]
+            cur_module = module_name_list[-1]
+            cur_module = moe_mapping[cur_module]
         new_module_name = '_'.join([layer, cur_module])
-    elif len(module_name) == 1:
-        new_module_name = module_name[0]
-        layer = None
-        cur_module = None
+        if exp_id is not None:
+            new_module_name += '_' + exp_id
+    else:
+        if len(module_name_list) == 5:
+            layer = module_name_list[2]
+            cur_module = module_name_list[-1][:-5]
+            new_module_name = '_'.join([layer, cur_module])
+        elif len(module_name_list) == 1:
+            new_module_name = module_name_list[0]
+            layer = None
+            cur_module = None
     return new_module_name, layer, cur_module


-def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
+def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_type=None):
     cur_qtype = qtype
     if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
         # For quantization which needs importance matrix
         new_module_name, layer, cur_module = module_name_process(full_module_name)
         # custom mixed quantization strategy
-        if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
-            cur_qtype = ggml_tensor_qtype['q2_k']
+        if model_type == "mixtral":
+            if cur_module == 'v':
+                # llama.cpp use q4_K here
+                cur_qtype = ggml_tensor_qtype['sym_int4']
+            elif cur_module == 'down' and int(layer) in [0, 1, 2, 3]:
+                cur_qtype = ggml_tensor_qtype['q2_k']
+        else:
+            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
+                cur_qtype = ggml_tensor_qtype['q2_k']
         if imatrix_data is not None and new_module_name in imatrix_data:
             cur_imatrix = imatrix_data[new_module_name]
         else:
-            # if no imatrix is available, use fp8 for lm_head
+            # if no imatrix is available, use sym_int8 for lm_head
             cur_imatrix = None
             if new_module_name == 'lm_head':
                 cur_qtype = ggml_tensor_qtype['sym_int8']
@@ -263,7 +298,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
         if imatrix_data is not None and new_module_name in imatrix_data:
             cur_imatrix = imatrix_data[new_module_name]
         else:
-            # if no imatrix is available, use fp8 for lm_head
+            # if no imatrix is available, use sym_int8 for lm_head
             cur_imatrix = None
             if new_module_name == 'lm_head':
                 cur_qtype = ggml_tensor_qtype['sym_int8']
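For reference while reviewing, a minimal usage sketch of the updated name mapping (not part of the patch). It assumes the patched bigdl.llm.transformers.utils from this diff is importable; the module names and printed tuples are illustrative examples traced from the logic above, not output from a real model run.

```python
# Sketch: how module_name_process keys llama/mistral vs. mixtral modules.
from bigdl.llm.transformers.utils import module_name_process

# llama / mistral style name: "o_proj" is trimmed to "o", keyed by layer index
print(module_name_process("model.layers.31.self_attn.o_proj"))
# ('31_o', '31', 'o')

# mixtral router gate: mapped to the imatrix "gate_inp" entry
print(module_name_process("model.layers.0.block_sparse_moe.gate"))
# ('0_gate_inp', '0', 'gate_inp')

# mixtral expert MLP: w2 -> down, with the expert id appended to the key
print(module_name_process("model.layers.0.block_sparse_moe.experts.3.w2"))
# ('0_down_3', '0', 'down')
```

These keys ('0_gate_inp', '0_down_3', ...) line up with the blk.*.ffn_gate_inp.weight and blk.*.ffn_down.N.weight entries that load_imatrix_data now produces for mixtral, so expert-specific imatrix data can be looked up per linear module.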