Fix modules_not_to_convert argument (#10483)

This commit is contained in:
Yishuo Wang 2024-03-20 17:47:03 +08:00 committed by GitHub
parent cbe24cc7e6
commit cfdf8ad496

View file

@ -189,7 +189,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
current_key_name=None, convert_shape_only=False,
convert_shape_only=False,
cpu_embedding=False, prefix_name='',
imatrix_data=None, embedding_qtype=None,
model_type=None, torch_dtype=torch.float32,
@ -200,16 +200,14 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
has_been_replaced = False
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
is_linear, linear_args = is_linear_module(module)
full_module_name = prefix_name + '.' + name if prefix_name != '' else name
if is_linear and name not in modules_to_not_convert and \
full_module_name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
not isinstance(module, LowBitLinear)):
# use sub-string to match, it may match `10` if user only pass a number like `0`
if any(key in full_module_name for key in modules_to_not_convert):
continue
if is_linear and not isinstance(module, LowBitLinear):
in_features, out_features, mp_group = linear_args
optimize_lm_head = False
if name == "lm_head":
@ -375,7 +373,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
module,
qtype,
modules_to_not_convert,
current_key_name,
convert_shape_only,
cpu_embedding,
prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
@ -664,7 +661,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
model_type = None
model, has_been_replaced = _replace_with_low_bit_linear(
model, qtype, modules_to_not_convert,
None, convert_shape_only, cpu_embedding,
convert_shape_only, cpu_embedding,
imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype,
model_type=model_type,