LLM: add full module name during convert (#10035)

This commit is contained in:
Ruonan Wang 2024-01-30 14:43:07 +08:00 committed by GitHub
parent 7dfa6dbe46
commit 6b63ba23d1

View file

@ -190,7 +190,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
current_key_name=None, convert_shape_only=False, current_key_name=None, convert_shape_only=False,
cpu_embedding=False): cpu_embedding=False, prefix_name=''):
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \ from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
FP16Linear, BF16Linear FP16Linear, BF16Linear
from bigdl.llm.transformers.embedding import LLMEmbedding from bigdl.llm.transformers.embedding import LLMEmbedding
@ -201,7 +201,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
current_key_name = [] current_key_name = []
is_linear, linear_args = is_linear_module(module) is_linear, linear_args = is_linear_module(module)
if is_linear and name not in modules_to_not_convert: full_module_name = prefix_name + '.' + name if prefix_name != '' else name
if is_linear and name not in modules_to_not_convert and \
full_module_name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert` # Check if the current key is not in the `modules_to_not_convert`
if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
not isinstance(module, LowBitLinear)): not isinstance(module, LowBitLinear)):
@ -323,6 +325,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
current_key_name, current_key_name,
convert_shape_only, convert_shape_only,
cpu_embedding, cpu_embedding,
prefix_name=prefix_name + '.' + name if prefix_name != '' else name
) )
has_been_replaced = _flag or has_been_replaced has_been_replaced = _flag or has_been_replaced
return model, has_been_replaced return model, has_been_replaced