diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index ec75aee7..89dcee8d 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -582,10 +582,6 @@ class _BaseAutoModelClass: else: model = model_class(config, *model_args, **kwargs) - # rwkv model linear layers has been rescaled - if model.config.model_type == "rwkv": - model.layers_are_rescaled = True - # Loading args may differ based on their usage quant_device = "meta" if bigdl_lcmu_enabled else "cpu" model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device, @@ -647,6 +643,10 @@ class _BaseAutoModelClass: pass for param in model.parameters(): param.requires_grad_(False) + + # rwkv model linear layers have been rescaled + if model.config.model_type == "rwkv": + model.rwkv.layers_are_rescaled = True return model