LLM: support iq2 for mixtral (#10191)
* support name mapping for mixtral * support mixtral mixed quantization * fix style * fix
This commit is contained in:
		
							parent
							
								
									079f2011ea
								
							
						
					
					
						commit
						f7c96b19ef
					
				
					 2 changed files with 66 additions and 22 deletions
				
			
		| 
						 | 
				
			
			@ -191,7 +191,8 @@ def convert_gptq(module, awq=False, llm_awq=False):
 | 
			
		|||
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		||||
                                 current_key_name=None, convert_shape_only=False,
 | 
			
		||||
                                 cpu_embedding=False, prefix_name='',
 | 
			
		||||
                                 imatrix_data=None, embedding_qtype=None):
 | 
			
		||||
                                 imatrix_data=None, embedding_qtype=None,
 | 
			
		||||
                                 model_type=None):
 | 
			
		||||
    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
 | 
			
		||||
        FP16Linear, BF16Linear
 | 
			
		||||
    from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
 | 
			
		||||
| 
						 | 
				
			
			@ -251,7 +252,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                        )
 | 
			
		||||
                        cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
 | 
			
		||||
                                                                           full_module_name,
 | 
			
		||||
                                                                           imatrix_data)
 | 
			
		||||
                                                                           imatrix_data,
 | 
			
		||||
                                                                           model_type)
 | 
			
		||||
                        device = module.weight.data.device
 | 
			
		||||
                        # Copy the weights
 | 
			
		||||
                        paramsLowBit = FP4Params(data=module.weight.data,
 | 
			
		||||
| 
						 | 
				
			
			@ -361,7 +363,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                cpu_embedding,
 | 
			
		||||
                prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
 | 
			
		||||
                imatrix_data=imatrix_data,
 | 
			
		||||
                embedding_qtype=embedding_qtype
 | 
			
		||||
                embedding_qtype=embedding_qtype,
 | 
			
		||||
                model_type=model_type
 | 
			
		||||
            )
 | 
			
		||||
            has_been_replaced = _flag or has_been_replaced
 | 
			
		||||
    return model, has_been_replaced
 | 
			
		||||
| 
						 | 
				
			
			@ -558,11 +561,17 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 | 
			
		|||
    if optimize_model:
 | 
			
		||||
        model = _optimize_pre(model)
 | 
			
		||||
 | 
			
		||||
    # mixed quantization needs model_type to choose custom quantization strategy
 | 
			
		||||
    if hasattr(model, "config"):
 | 
			
		||||
        model_type = getattr(model.config, "model_type", None)
 | 
			
		||||
    else:
 | 
			
		||||
        model_type = None
 | 
			
		||||
    model, has_been_replaced = _replace_with_low_bit_linear(
 | 
			
		||||
        model, qtype, modules_to_not_convert,
 | 
			
		||||
        None, convert_shape_only, cpu_embedding,
 | 
			
		||||
        imatrix_data=imatrix_data,
 | 
			
		||||
        embedding_qtype=embedding_qtype
 | 
			
		||||
        embedding_qtype=embedding_qtype,
 | 
			
		||||
        model_type=model_type
 | 
			
		||||
    )
 | 
			
		||||
    if not has_been_replaced:
 | 
			
		||||
        warnings.warn(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -197,16 +197,24 @@ def load_imatrix_data(imatrix_file):
 | 
			
		|||
        cur_len = imatrix.read(4)
 | 
			
		||||
        cur_len = int.from_bytes(cur_len, 'little')
 | 
			
		||||
        cur_name = str(imatrix.read(cur_len), encoding='utf-8')
 | 
			
		||||
        # original cur_name looks like blk.14.attn_output.weight for llama
 | 
			
		||||
        # TODO: how to better aligned and generalize
 | 
			
		||||
        # cur_name looks like blk.14.attn_output.weight for llama / mistral,
 | 
			
		||||
        # cur_name looks like blk.0.ffn_down.3.weight for mixtral and
 | 
			
		||||
        # blk.17.ffn_gate_inp.weight for mixtral
 | 
			
		||||
        name_list = cur_name.split('.')
 | 
			
		||||
        layer = name_list[1]
 | 
			
		||||
        module_name = name_list[2]
 | 
			
		||||
        if 'ffn' in module_name:
 | 
			
		||||
        exp_id = None
 | 
			
		||||
        if 'ffn' in module_name and len(name_list) == 4:
 | 
			
		||||
            module_name = module_name[4:]  # from ffn_gate to gate
 | 
			
		||||
        elif 'ffn' in module_name and len(name_list) == 5:
 | 
			
		||||
            # mixtral's mlp layer
 | 
			
		||||
            module_name = module_name[4:]
 | 
			
		||||
            exp_id = name_list[3]
 | 
			
		||||
        elif 'attn' in module_name:
 | 
			
		||||
            module_name = module_name[5]  # from attn_k to k, attn_output to o
 | 
			
		||||
        module_name = layer + '_' + module_name
 | 
			
		||||
        if exp_id is not None:
 | 
			
		||||
            module_name += '_' + exp_id
 | 
			
		||||
        ncall = imatrix.read(4)
 | 
			
		||||
        ncall = int.from_bytes(ncall, 'little')
 | 
			
		||||
        nval = imatrix.read(4)
 | 
			
		||||
| 
						 | 
				
			
			@ -225,32 +233,59 @@ def load_imatrix_data(imatrix_file):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def module_name_process(full_module_name):
 | 
			
		||||
    # full name maybe model.layers.31.self_attn.o_proj
 | 
			
		||||
    # TODO: how to better aligned and generalize
 | 
			
		||||
    module_name = full_module_name.split('.')
 | 
			
		||||
    if len(module_name) == 5:
 | 
			
		||||
        layer = module_name[2]
 | 
			
		||||
        cur_module = module_name[-1][:-5]
 | 
			
		||||
    # full name maybe model.layers.31.self_attn.o_proj for llama/mistral
 | 
			
		||||
    # full name maybe model.layers.0.block_sparse_moe.gate or
 | 
			
		||||
    # model.layers.0.block_sparse_moe.experts.0.w1 for mixtral
 | 
			
		||||
    module_name_list = full_module_name.split('.')
 | 
			
		||||
    if len(module_name_list) >= 5:
 | 
			
		||||
        super_module_name = module_name_list[3]
 | 
			
		||||
    else:
 | 
			
		||||
        super_module_name = None
 | 
			
		||||
    exp_id = None
 | 
			
		||||
    if super_module_name == 'block_sparse_moe':
 | 
			
		||||
        # handle mixtral moe here
 | 
			
		||||
        moe_mapping = {"w1": "gate", "w2": "down", "w3": "up"}
 | 
			
		||||
        layer = module_name_list[2]
 | 
			
		||||
        if len(module_name_list) == 5 and module_name_list[-1] == 'gate':
 | 
			
		||||
            cur_module = 'gate_inp'  # mapping with imatrix
 | 
			
		||||
        elif len(module_name_list) == 7:
 | 
			
		||||
            exp_id = module_name_list[-2]
 | 
			
		||||
            cur_module = module_name_list[-1]
 | 
			
		||||
            cur_module = moe_mapping[cur_module]
 | 
			
		||||
        new_module_name = '_'.join([layer, cur_module])
 | 
			
		||||
    elif len(module_name) == 1:
 | 
			
		||||
        new_module_name = module_name[0]
 | 
			
		||||
        if exp_id is not None:
 | 
			
		||||
            new_module_name += '_' + exp_id
 | 
			
		||||
    else:
 | 
			
		||||
        if len(module_name_list) == 5:
 | 
			
		||||
            layer = module_name_list[2]
 | 
			
		||||
            cur_module = module_name_list[-1][:-5]
 | 
			
		||||
            new_module_name = '_'.join([layer, cur_module])
 | 
			
		||||
        elif len(module_name_list) == 1:
 | 
			
		||||
            new_module_name = module_name_list[0]
 | 
			
		||||
            layer = None
 | 
			
		||||
            cur_module = None
 | 
			
		||||
    return new_module_name, layer, cur_module
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
 | 
			
		||||
def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_type=None):
 | 
			
		||||
    cur_qtype = qtype
 | 
			
		||||
    if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
 | 
			
		||||
        # For quantization which needs importance matrix
 | 
			
		||||
        new_module_name, layer, cur_module = module_name_process(full_module_name)
 | 
			
		||||
        # custom mixed quantization strategy
 | 
			
		||||
        if model_type == "mixtral":
 | 
			
		||||
            if cur_module == 'v':
 | 
			
		||||
                # llama.cpp use q4_K here
 | 
			
		||||
                cur_qtype = ggml_tensor_qtype['sym_int4']
 | 
			
		||||
            elif cur_module == 'down' and int(layer) in [0, 1, 2, 3]:
 | 
			
		||||
                cur_qtype = ggml_tensor_qtype['q2_k']
 | 
			
		||||
        else:
 | 
			
		||||
            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
 | 
			
		||||
                cur_qtype = ggml_tensor_qtype['q2_k']
 | 
			
		||||
        if imatrix_data is not None and new_module_name in imatrix_data:
 | 
			
		||||
            cur_imatrix = imatrix_data[new_module_name]
 | 
			
		||||
        else:
 | 
			
		||||
            # if no imatrix is available, use fp8 for lm_head
 | 
			
		||||
            # if no imatrix is available, use sym_int8 for lm_head
 | 
			
		||||
            cur_imatrix = None
 | 
			
		||||
            if new_module_name == 'lm_head':
 | 
			
		||||
                cur_qtype = ggml_tensor_qtype['sym_int8']
 | 
			
		||||
| 
						 | 
				
			
			@ -263,7 +298,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
 | 
			
		|||
        if imatrix_data is not None and new_module_name in imatrix_data:
 | 
			
		||||
            cur_imatrix = imatrix_data[new_module_name]
 | 
			
		||||
        else:
 | 
			
		||||
            # if no imatrix is available, use fp8 for lm_head
 | 
			
		||||
            # if no imatrix is available, use sym_int8 for lm_head
 | 
			
		||||
            cur_imatrix = None
 | 
			
		||||
            if new_module_name == 'lm_head':
 | 
			
		||||
                cur_qtype = ggml_tensor_qtype['sym_int8']
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue