Fix modules_not_to_convert argument (#10483)
				
					
				
			This commit is contained in:
		
							parent
							
								
									cbe24cc7e6
								
							
						
					
					
						commit
						cfdf8ad496
					
				
					 1 changed file with 122 additions and 125 deletions
				
			
		| 
						 | 
					@ -189,7 +189,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
					def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
				
			||||||
                                 current_key_name=None, convert_shape_only=False,
 | 
					                                 convert_shape_only=False,
 | 
				
			||||||
                                 cpu_embedding=False, prefix_name='',
 | 
					                                 cpu_embedding=False, prefix_name='',
 | 
				
			||||||
                                 imatrix_data=None, embedding_qtype=None,
 | 
					                                 imatrix_data=None, embedding_qtype=None,
 | 
				
			||||||
                                 model_type=None, torch_dtype=torch.float32,
 | 
					                                 model_type=None, torch_dtype=torch.float32,
 | 
				
			||||||
| 
						 | 
					@ -200,16 +200,14 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
				
			||||||
    has_been_replaced = False
 | 
					    has_been_replaced = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for name, module in model.named_children():
 | 
					    for name, module in model.named_children():
 | 
				
			||||||
        if current_key_name is None:
 | 
					 | 
				
			||||||
            current_key_name = []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        is_linear, linear_args = is_linear_module(module)
 | 
					        is_linear, linear_args = is_linear_module(module)
 | 
				
			||||||
        full_module_name = prefix_name + '.' + name if prefix_name != '' else name
 | 
					        full_module_name = prefix_name + '.' + name if prefix_name != '' else name
 | 
				
			||||||
        if is_linear and name not in modules_to_not_convert and \
 | 
					
 | 
				
			||||||
                full_module_name not in modules_to_not_convert:
 | 
					        # Match by substring; note that a bare number such as `0` will also match names containing `10`
 | 
				
			||||||
            # Check if the current key is not in the `modules_to_not_convert`
 | 
					        if any(key in full_module_name for key in modules_to_not_convert):
 | 
				
			||||||
            if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
 | 
					            continue
 | 
				
			||||||
                    not isinstance(module, LowBitLinear)):
 | 
					
 | 
				
			||||||
 | 
					        if is_linear and not isinstance(module, LowBitLinear):
 | 
				
			||||||
            in_features, out_features, mp_group = linear_args
 | 
					            in_features, out_features, mp_group = linear_args
 | 
				
			||||||
            optimize_lm_head = False
 | 
					            optimize_lm_head = False
 | 
				
			||||||
            if name == "lm_head":
 | 
					            if name == "lm_head":
 | 
				
			||||||
| 
						 | 
					@ -375,7 +373,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
				
			||||||
                module,
 | 
					                module,
 | 
				
			||||||
                qtype,
 | 
					                qtype,
 | 
				
			||||||
                modules_to_not_convert,
 | 
					                modules_to_not_convert,
 | 
				
			||||||
                current_key_name,
 | 
					 | 
				
			||||||
                convert_shape_only,
 | 
					                convert_shape_only,
 | 
				
			||||||
                cpu_embedding,
 | 
					                cpu_embedding,
 | 
				
			||||||
                prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
 | 
					                prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
 | 
				
			||||||
| 
						 | 
					@ -664,7 +661,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 | 
				
			||||||
        model_type = None
 | 
					        model_type = None
 | 
				
			||||||
    model, has_been_replaced = _replace_with_low_bit_linear(
 | 
					    model, has_been_replaced = _replace_with_low_bit_linear(
 | 
				
			||||||
        model, qtype, modules_to_not_convert,
 | 
					        model, qtype, modules_to_not_convert,
 | 
				
			||||||
        None, convert_shape_only, cpu_embedding,
 | 
					        convert_shape_only, cpu_embedding,
 | 
				
			||||||
        imatrix_data=imatrix_data,
 | 
					        imatrix_data=imatrix_data,
 | 
				
			||||||
        embedding_qtype=embedding_qtype,
 | 
					        embedding_qtype=embedding_qtype,
 | 
				
			||||||
        model_type=model_type,
 | 
					        model_type=model_type,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue