diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index c11a6830..36950230 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -190,7 +190,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False,
-                                 cpu_embedding=False):
+                                 cpu_embedding=False, prefix_name=''):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from bigdl.llm.transformers.embedding import LLMEmbedding
@@ -201,7 +201,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             current_key_name = []
 
         is_linear, linear_args = is_linear_module(module)
-        if is_linear and name not in modules_to_not_convert:
+        full_module_name = prefix_name + '.' + name if prefix_name != '' else name
+        if is_linear and name not in modules_to_not_convert and \
+                full_module_name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
             if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
                     not isinstance(module, LowBitLinear)):
@@ -323,6 +325,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 current_key_name,
                 convert_shape_only,
                 cpu_embedding,
+                prefix_name=prefix_name + '.' + name if prefix_name != '' else name
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
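
Below is a minimal sketch of what the `prefix_name` threading buys. It is not part of the patch; `collect_skipped`, `Block`, and `Toy` are hypothetical names used only for illustration. Because each recursive call extends the prefix with the child's name, an entry in `modules_to_not_convert` can now target one specific submodule by its fully qualified name, instead of disabling every linear layer that happens to share a short name.

```python
# Hypothetical illustration of the full_module_name check added in this patch.
import torch.nn as nn

def collect_skipped(model, modules_to_not_convert, prefix_name=''):
    """Walk the module tree the same way _replace_with_low_bit_linear does
    and return the fully qualified names the patched check would skip."""
    skipped = []
    for name, module in model.named_children():
        # Mirrors the patched logic: build the fully qualified module name.
        full_module_name = prefix_name + '.' + name if prefix_name != '' else name
        if isinstance(module, nn.Linear) and \
                (name in modules_to_not_convert or
                 full_module_name in modules_to_not_convert):
            skipped.append(full_module_name)
        if len(list(module.children())) > 0:
            # Recurse with the extended prefix, as the patched recursion does.
            skipped += collect_skipped(module, modules_to_not_convert,
                                       prefix_name=full_module_name)
    return skipped

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.h = nn.ModuleList([Block(), Block()])

# Only block 0's projection is excluded; before this patch, an exclusion
# could only name 'dense', which would have matched both blocks.
print(collect_skipped(Toy(), ['h.0.dense']))  # -> ['h.0.dense']
```

In the patch itself the same condition appears negated: a linear module is converted only when neither its short name nor its fully qualified name is listed in `modules_to_not_convert`, so existing short-name exclusions keep working unchanged.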