Fix modules_not_to_convert argument (#10483)
parent cbe24cc7e6
commit cfdf8ad496
1 changed file with 122 additions and 125 deletions
@@ -189,7 +189,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
-                                 current_key_name=None, convert_shape_only=False,
+                                 convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_type=None, torch_dtype=torch.float32,
@@ -200,133 +200,131 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
     has_been_replaced = False
 
     for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
-
         is_linear, linear_args = is_linear_module(module)
         full_module_name = prefix_name + '.' + name if prefix_name != '' else name
-        if is_linear and name not in modules_to_not_convert and \
-                full_module_name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
-                    not isinstance(module, LowBitLinear)):
-                in_features, out_features, mp_group = linear_args
-                optimize_lm_head = False
-                if name == "lm_head":
-                    if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
-                                                                          None) == "1":
-                        optimize_lm_head = True
-                with init_empty_weights():
-                    new_linear = None
-                    is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
-                    is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
-                    is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
-                    if is_gptq or is_awq:
-                        has_bias = module.bias is not None and module.bias.abs().sum() != 0
-                        new_linear = LowBitLinear(
-                            in_features,
-                            out_features,
-                            qtype=qtype,
-                            bias=has_bias,
-                            mp_group=mp_group,
-                            enable_xetla=enable_xetla,
-                            optimize_lm_head=optimize_lm_head
-                        )
-                        device = module.qweight.data.device
-                        invalidInputError(device.type != "meta",
-                                          "converting from meta device is not supported")
-                        # Copy the weights
-                        paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
-                                                                   llm_awq=is_llm_awq),
-                                                 requires_grad=False,
-                                                 quantized=True,
-                                                 _shape=(out_features, in_features),
-                                                 convert_shape_only=convert_shape_only,
-                                                 qtype=qtype,
-                                                 enable_xetla=enable_xetla).to(device)
-                        new_linear._parameters['weight'] = paramsLowBit
-                        if has_bias:
-                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device)
-                    elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
-                        if in_features % 64 != 0:
-                            # now our kernel requires in_features is a multiple of 64
-                            continue
-                        new_linear = LowBitLinear(
-                            in_features,
-                            out_features,
-                            qtype,
-                            module.bias is not None,
-                            mp_group=mp_group,
-                            enable_xetla=enable_xetla,
-                            optimize_lm_head=optimize_lm_head
-                        )
-                        cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
-                                                                           full_module_name,
-                                                                           imatrix_data,
-                                                                           model_type)
-                        device = module.weight.data.device
-                        # Copy the weights
-                        paramsLowBit = FP4Params(data=module.weight.data,
-                                                 requires_grad=False,
-                                                 quantized=False,
-                                                 _shape=None,
-                                                 convert_shape_only=convert_shape_only,
-                                                 qtype=cur_qtype,
-                                                 imatrix=cur_imatrix,
-                                                 in_features=in_features,
-                                                 enable_xetla=enable_xetla).to(device)
-                        new_linear._parameters['weight'] = paramsLowBit
-                        if module.bias is not None:
-                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device)
-                    elif qtype == ggml_tensor_qtype["fp16"]:
-                        module.to(torch.float16)
-                        new_linear = FP16Linear(
-                            in_features,
-                            out_features,
-                            module.bias is not None,
-                            mp_group=mp_group,
-                            optimize_lm_head=optimize_lm_head
-                        )
-                        device = module.weight.data.device
-                        from bigdl.llm.transformers.utils import get_ipex_version
-                        if get_ipex_version() < "2.1.10+xpu":
-                            new_linear._parameters['weight'] = nn.Parameter(module.weight)
-                        else:
-                            # only from 2.1, ipex provides matmul_bias_out
-                            # so we need to transpose weight
-                            new_weight = module.weight.transpose(0, 1).contiguous()
-                            new_linear._parameters['weight'] = nn.Parameter(new_weight)
-                            new_linear.weight_type = 2
-                        if module.bias is not None:
-                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device)
-                    elif qtype == ggml_tensor_qtype["bf16"]:
-                        module.to(torch.bfloat16)
-                        new_linear = BF16Linear(
-                            in_features,
-                            out_features,
-                            module.bias is not None,
-                            mp_group=mp_group,
-                            optimize_lm_head=optimize_lm_head
-                        )
-                        device = module.weight.data.device
-                        # convert here
-                        new_linear._parameters['weight'] = nn.Parameter(module.weight)
-                        if module.bias is not None:
-                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device)
-                    if new_linear is not None:
-                        if not module.training:
-                            new_linear.eval()
-                        model._modules[name] = new_linear
-                        has_been_replaced = True
-                        # Force requires grad to False to avoid unexpected errors
-                        model._modules[name].requires_grad_(False)
-
-                        module.weight = None
+        # use sub-string to match, it may match `10` if user only pass a number like `0`
+        if any(key in full_module_name for key in modules_to_not_convert):
+            continue
+
+        if is_linear and not isinstance(module, LowBitLinear):
+            in_features, out_features, mp_group = linear_args
+            optimize_lm_head = False
+            if name == "lm_head":
+                if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
+                                                                      None) == "1":
+                    optimize_lm_head = True
+            with init_empty_weights():
+                new_linear = None
+                is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
+                is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
+                is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
+                if is_gptq or is_awq:
+                    has_bias = module.bias is not None and module.bias.abs().sum() != 0
+                    new_linear = LowBitLinear(
+                        in_features,
+                        out_features,
+                        qtype=qtype,
+                        bias=has_bias,
+                        mp_group=mp_group,
+                        enable_xetla=enable_xetla,
+                        optimize_lm_head=optimize_lm_head
+                    )
+                    device = module.qweight.data.device
+                    invalidInputError(device.type != "meta",
+                                      "converting from meta device is not supported")
+                    # Copy the weights
+                    paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
+                                                               llm_awq=is_llm_awq),
+                                             requires_grad=False,
+                                             quantized=True,
+                                             _shape=(out_features, in_features),
+                                             convert_shape_only=convert_shape_only,
+                                             qtype=qtype,
+                                             enable_xetla=enable_xetla).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
+                    if has_bias:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
+                    if in_features % 64 != 0:
+                        # now our kernel requires in_features is a multiple of 64
+                        continue
+                    new_linear = LowBitLinear(
+                        in_features,
+                        out_features,
+                        qtype,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                        enable_xetla=enable_xetla,
+                        optimize_lm_head=optimize_lm_head
+                    )
+                    cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
+                                                                       full_module_name,
+                                                                       imatrix_data,
+                                                                       model_type)
+                    device = module.weight.data.device
+                    # Copy the weights
+                    paramsLowBit = FP4Params(data=module.weight.data,
+                                             requires_grad=False,
+                                             quantized=False,
+                                             _shape=None,
+                                             convert_shape_only=convert_shape_only,
+                                             qtype=cur_qtype,
+                                             imatrix=cur_imatrix,
+                                             in_features=in_features,
+                                             enable_xetla=enable_xetla).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype == ggml_tensor_qtype["fp16"]:
+                    module.to(torch.float16)
+                    new_linear = FP16Linear(
+                        in_features,
+                        out_features,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                        optimize_lm_head=optimize_lm_head
+                    )
+                    device = module.weight.data.device
+                    from bigdl.llm.transformers.utils import get_ipex_version
+                    if get_ipex_version() < "2.1.10+xpu":
+                        new_linear._parameters['weight'] = nn.Parameter(module.weight)
+                    else:
+                        # only from 2.1, ipex provides matmul_bias_out
+                        # so we need to transpose weight
+                        new_weight = module.weight.transpose(0, 1).contiguous()
+                        new_linear._parameters['weight'] = nn.Parameter(new_weight)
+                        new_linear.weight_type = 2
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype == ggml_tensor_qtype["bf16"]:
+                    module.to(torch.bfloat16)
+                    new_linear = BF16Linear(
+                        in_features,
+                        out_features,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                        optimize_lm_head=optimize_lm_head
+                    )
+                    device = module.weight.data.device
+                    # convert here
+                    new_linear._parameters['weight'] = nn.Parameter(module.weight)
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                if new_linear is not None:
+                    if not module.training:
+                        new_linear.eval()
+                    model._modules[name] = new_linear
+                    has_been_replaced = True
+                    # Force requires grad to False to avoid unexpected errors
+                    model._modules[name].requires_grad_(False)
+
+                    module.weight = None
         elif cpu_embedding and type(module) == nn.Embedding:
             # skip user-defined Embedding layer
             model._modules[name] = LLMEmbedding(
@@ -375,7 +373,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 module,
                 qtype,
                 modules_to_not_convert,
-                current_key_name,
                 convert_shape_only,
                 cpu_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
@@ -664,7 +661,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         model_type = None
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        None, convert_shape_only, cpu_embedding,
+        convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
         model_type=model_type,
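
For reference, a minimal standalone sketch of the skip rule introduced above (the helper name should_skip and the sample module names are illustrative only, not part of the patch). The new check is a plain substring match against each module's full name, so, as the in-code comment warns, a very short entry such as "0" in modules_to_not_convert also matches layer 10:

# Minimal sketch of the substring-based skip rule, assuming the same kind of
# modules_to_not_convert list that _replace_with_low_bit_linear receives.
def should_skip(full_module_name, modules_to_not_convert):
    # A module is left unconverted if any entry occurs anywhere in its full name.
    return any(key in full_module_name for key in modules_to_not_convert)

print(should_skip("model.lm_head", ["lm_head"]))                    # True: skipped
print(should_skip("model.layers.10.mlp.down_proj", ["0"]))          # True: "0" also matches layer 10
print(should_skip("model.layers.2.self_attn.q_proj", ["lm_head"]))  # False: converted as usual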