LLM: add full module name during convert (#10035)
This commit is contained in:
parent
7dfa6dbe46
commit
6b63ba23d1
1 changed file with 5 additions and 2 deletions
|
|
@@ -190,7 +190,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
|
|||
|
||||
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||
current_key_name=None, convert_shape_only=False,
|
||||
cpu_embedding=False):
|
||||
cpu_embedding=False, prefix_name=''):
|
||||
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
|
||||
FP16Linear, BF16Linear
|
||||
from bigdl.llm.transformers.embedding import LLMEmbedding
|
||||
|
|
@@ -201,7 +201,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
current_key_name = []
|
||||
|
||||
is_linear, linear_args = is_linear_module(module)
|
||||
if is_linear and name not in modules_to_not_convert:
|
||||
full_module_name = prefix_name + '.' + name if prefix_name != '' else name
|
||||
if is_linear and name not in modules_to_not_convert and \
|
||||
full_module_name not in modules_to_not_convert:
|
||||
# Check if the current key is not in the `modules_to_not_convert`
|
||||
if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
|
||||
not isinstance(module, LowBitLinear)):
|
||||
|
|
@@ -323,6 +325,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
current_key_name,
|
||||
convert_shape_only,
|
||||
cpu_embedding,
|
||||
prefix_name=prefix_name + '.' + name if prefix_name != '' else name
|
||||
)
|
||||
has_been_replaced = _flag or has_been_replaced
|
||||
return model, has_been_replaced
|
||||
|
|
|
|||
Loading…
Reference in a new issue