LLM: add full module name during convert (#10035)
This commit is contained in:
parent
7dfa6dbe46
commit
6b63ba23d1
1 changed files with 5 additions and 2 deletions
|
|
@ -190,7 +190,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
|
||||||
|
|
||||||
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||||
current_key_name=None, convert_shape_only=False,
|
current_key_name=None, convert_shape_only=False,
|
||||||
cpu_embedding=False):
|
cpu_embedding=False, prefix_name=''):
|
||||||
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
|
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
|
||||||
FP16Linear, BF16Linear
|
FP16Linear, BF16Linear
|
||||||
from bigdl.llm.transformers.embedding import LLMEmbedding
|
from bigdl.llm.transformers.embedding import LLMEmbedding
|
||||||
|
|
@ -201,7 +201,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||||
current_key_name = []
|
current_key_name = []
|
||||||
|
|
||||||
is_linear, linear_args = is_linear_module(module)
|
is_linear, linear_args = is_linear_module(module)
|
||||||
if is_linear and name not in modules_to_not_convert:
|
full_module_name = prefix_name + '.' + name if prefix_name != '' else name
|
||||||
|
if is_linear and name not in modules_to_not_convert and \
|
||||||
|
full_module_name not in modules_to_not_convert:
|
||||||
# Check if the current key is not in the `modules_to_not_convert`
|
# Check if the current key is not in the `modules_to_not_convert`
|
||||||
if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
|
if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
|
||||||
not isinstance(module, LowBitLinear)):
|
not isinstance(module, LowBitLinear)):
|
||||||
|
|
@ -323,6 +325,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||||
current_key_name,
|
current_key_name,
|
||||||
convert_shape_only,
|
convert_shape_only,
|
||||||
cpu_embedding,
|
cpu_embedding,
|
||||||
|
prefix_name=prefix_name + '.' + name if prefix_name != '' else name
|
||||||
)
|
)
|
||||||
has_been_replaced = _flag or has_been_replaced
|
has_been_replaced = _flag or has_been_replaced
|
||||||
return model, has_been_replaced
|
return model, has_been_replaced
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue