Fix modules_not_to_convert argument (#10483)

This commit is contained in:
Yishuo Wang 2024-03-20 17:47:03 +08:00 committed by GitHub
parent cbe24cc7e6
commit cfdf8ad496

View file

@@ -189,7 +189,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
-                                 current_key_name=None, convert_shape_only=False,
+                                 convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_type=None, torch_dtype=torch.float32,
@@ -200,16 +200,14 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
     has_been_replaced = False
     for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
         is_linear, linear_args = is_linear_module(module)
         full_module_name = prefix_name + '.' + name if prefix_name != '' else name
-        if is_linear and name not in modules_to_not_convert and \
-                full_module_name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
-                    not isinstance(module, LowBitLinear)):
+        # use sub-string to match, it may match `10` if user only pass a number like `0`
+        if any(key in full_module_name for key in modules_to_not_convert):
+            continue
+
+        if is_linear and not isinstance(module, LowBitLinear):
             in_features, out_features, mp_group = linear_args
             optimize_lm_head = False
             if name == "lm_head":
@@ -375,7 +373,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 module,
                 qtype,
                 modules_to_not_convert,
-                current_key_name,
                 convert_shape_only,
                 cpu_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
@@ -664,7 +661,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         model_type = None
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        None, convert_shape_only, cpu_embedding,
+        convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
         model_type=model_type,