LLM: support iq2 for mixtral (#10191)
* support name mapping for mixtral
* support mixtral mixed quantization
* fix style
* fix
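
The name mapping is the core of this change: llama.cpp imatrix files name Mixtral tensors like blk.0.ffn_down.3.weight (per-expert weights) and blk.17.ffn_gate_inp.weight (the MoE router), while the Hugging Face module tree uses model.layers.0.block_sparse_moe.experts.3.w2 and model.layers.17.block_sparse_moe.gate. Both sides are now normalized to a shared '<layer>_<module>[_<expert>]' key so per-module importance data can be looked up during quantization. An illustrative correspondence (derived from the diff below):

    # GGUF imatrix name              HF module name                                 shared key
    # blk.14.attn_output.weight      model.layers.14.self_attn.o_proj               '14_o'
    # blk.0.ffn_down.3.weight        model.layers.0.block_sparse_moe.experts.3.w2   '0_down_3'
    # blk.17.ffn_gate_inp.weight     model.layers.17.block_sparse_moe.gate          '17_gate_inp'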
parent 079f2011ea
commit f7c96b19ef

2 changed files with 66 additions and 22 deletions
File 1 of 2:

@@ -191,7 +191,8 @@ def convert_gptq(module, awq=False, llm_awq=False):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
-                                 imatrix_data=None, embedding_qtype=None):
+                                 imatrix_data=None, embedding_qtype=None,
+                                 model_type=None):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@@ -251,7 +252,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         )
                         cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
                                                                            full_module_name,
-                                                                           imatrix_data)
+                                                                           imatrix_data,
+                                                                           model_type)
                         device = module.weight.data.device
                         # Copy the weights
                         paramsLowBit = FP4Params(data=module.weight.data,
@@ -361,7 +363,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 cpu_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
-                embedding_qtype=embedding_qtype
+                embedding_qtype=embedding_qtype,
+                model_type=model_type
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -558,11 +561,17 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     if optimize_model:
         model = _optimize_pre(model)

+    # mixed quantization needs model_type to choose a custom quantization strategy
+    if hasattr(model, "config"):
+        model_type = getattr(model.config, "model_type", None)
+    else:
+        model_type = None
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
         None, convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
-        embedding_qtype=embedding_qtype
+        embedding_qtype=embedding_qtype,
+        model_type=model_type
     )
     if not has_been_replaced:
         warnings.warn(
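
Note: model_type is read from the standard Hugging Face config, so no new user-facing argument is needed; Mixtral checkpoints report "mixtral". A quick way to confirm (the checkpoint id below is illustrative; any Mixtral checkpoint exposes the same field):

    from transformers import AutoConfig

    # model_type comes from the checkpoint's config.json
    cfg = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
    print(cfg.model_type)  # -> "mixtral"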
File 2 of 2:
@@ -197,16 +197,24 @@ def load_imatrix_data(imatrix_file):
         cur_len = imatrix.read(4)
         cur_len = int.from_bytes(cur_len, 'little')
         cur_name = str(imatrix.read(cur_len), encoding='utf-8')
-        # original cur_name looks like blk.14.attn_output.weight for llama
-        # TODO: how to better aligned and generalize
+        # cur_name looks like blk.14.attn_output.weight for llama / mistral,
+        # blk.0.ffn_down.3.weight for mixtral experts and
+        # blk.17.ffn_gate_inp.weight for mixtral's moe gate
         name_list = cur_name.split('.')
         layer = name_list[1]
         module_name = name_list[2]
-        if 'ffn' in module_name:
+        exp_id = None
+        if 'ffn' in module_name and len(name_list) == 4:
             module_name = module_name[4:]  # from ffn_gate to gate
+        elif 'ffn' in module_name and len(name_list) == 5:
+            # mixtral's mlp layer
+            module_name = module_name[4:]
+            exp_id = name_list[3]
         elif 'attn' in module_name:
             module_name = module_name[5]  # from attn_k to k, attn_output to o
         module_name = layer + '_' + module_name
+        if exp_id is not None:
+            module_name += '_' + exp_id
         ncall = imatrix.read(4)
         ncall = int.from_bytes(ncall, 'little')
         nval = imatrix.read(4)
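
To make the renaming rule concrete, here is a minimal standalone sketch (the helper name imatrix_key is hypothetical; in the library this logic lives inline in load_imatrix_data):

    def imatrix_key(cur_name):
        parts = cur_name.split('.')        # e.g. ['blk', '0', 'ffn_down', '3', 'weight']
        layer, module_name = parts[1], parts[2]
        exp_id = None
        if 'ffn' in module_name and len(parts) == 4:
            module_name = module_name[4:]  # ffn_gate -> gate
        elif 'ffn' in module_name and len(parts) == 5:
            module_name = module_name[4:]  # per-expert mixtral weight
            exp_id = parts[3]
        elif 'attn' in module_name:
            module_name = module_name[5]   # attn_k -> k, attn_output -> o
        key = layer + '_' + module_name
        if exp_id is not None:
            key += '_' + exp_id
        return key

    assert imatrix_key('blk.14.attn_output.weight') == '14_o'
    assert imatrix_key('blk.0.ffn_down.3.weight') == '0_down_3'
    assert imatrix_key('blk.17.ffn_gate_inp.weight') == '17_gate_inp'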
@@ -225,32 +233,59 @@ def load_imatrix_data(imatrix_file):


 def module_name_process(full_module_name):
-    # full name maybe model.layers.31.self_attn.o_proj
-    # TODO: how to better aligned and generalize
-    module_name = full_module_name.split('.')
-    if len(module_name) == 5:
-        layer = module_name[2]
-        cur_module = module_name[-1][:-5]
+    # full name may be model.layers.31.self_attn.o_proj for llama/mistral,
+    # model.layers.0.block_sparse_moe.gate or
+    # model.layers.0.block_sparse_moe.experts.0.w1 for mixtral
+    module_name_list = full_module_name.split('.')
+    if len(module_name_list) >= 5:
+        super_module_name = module_name_list[3]
+    else:
+        super_module_name = None
+    exp_id = None
+    if super_module_name == 'block_sparse_moe':
+        # handle mixtral moe here
+        moe_mapping = {"w1": "gate", "w2": "down", "w3": "up"}
+        layer = module_name_list[2]
+        if len(module_name_list) == 5 and module_name_list[-1] == 'gate':
+            cur_module = 'gate_inp'  # matches the imatrix naming
+        elif len(module_name_list) == 7:
+            exp_id = module_name_list[-2]
+            cur_module = module_name_list[-1]
+            cur_module = moe_mapping[cur_module]
         new_module_name = '_'.join([layer, cur_module])
-    elif len(module_name) == 1:
-        new_module_name = module_name[0]
-        layer = None
-        cur_module = None
+        if exp_id is not None:
+            new_module_name += '_' + exp_id
+    else:
+        if len(module_name_list) == 5:
+            layer = module_name_list[2]
+            cur_module = module_name_list[-1][:-5]
+            new_module_name = '_'.join([layer, cur_module])
+        elif len(module_name_list) == 1:
+            new_module_name = module_name_list[0]
+            layer = None
+            cur_module = None
     return new_module_name, layer, cur_module


-def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
+def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_type=None):
     cur_qtype = qtype
     if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
         # For quantization which needs importance matrix
         new_module_name, layer, cur_module = module_name_process(full_module_name)
         # custom mixed quantization strategy
-        if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
-            cur_qtype = ggml_tensor_qtype['q2_k']
+        if model_type == "mixtral":
+            if cur_module == 'v':
+                # llama.cpp uses q4_K here
+                cur_qtype = ggml_tensor_qtype['sym_int4']
+            elif cur_module == 'down' and int(layer) in [0, 1, 2, 3]:
+                cur_qtype = ggml_tensor_qtype['q2_k']
+        else:
+            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
+                cur_qtype = ggml_tensor_qtype['q2_k']
         if imatrix_data is not None and new_module_name in imatrix_data:
             cur_imatrix = imatrix_data[new_module_name]
         else:
-            # if no imatrix is available, use fp8 for lm_head
+            # if no imatrix is available, use sym_int8 for lm_head
             cur_imatrix = None
             if new_module_name == 'lm_head':
                 cur_qtype = ggml_tensor_qtype['sym_int8']
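
The mixtral branch above encodes a small per-module override table for the iq2 qtypes. A hedged standalone restatement (illustrative helper, not a library API):

    def mixtral_iq2_override(cur_module, layer):
        # keep v_proj at 4-bit; llama.cpp uses q4_K for the same tensor
        if cur_module == 'v':
            return 'sym_int4'
        # bump down_proj in the first four layers up to q2_k instead of iq2
        if cur_module == 'down' and int(layer) in [0, 1, 2, 3]:
            return 'q2_k'
        return None  # keep the requested iq2_xxs / iq2_xs qtype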
@@ -263,7 +298,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
         if imatrix_data is not None and new_module_name in imatrix_data:
             cur_imatrix = imatrix_data[new_module_name]
         else:
-            # if no imatrix is available, use fp8 for lm_head
+            # if no imatrix is available, use sym_int8 for lm_head
             cur_imatrix = None
             if new_module_name == 'lm_head':
                 cur_qtype = ggml_tensor_qtype['sym_int8']