From 4fbf449c2da0cc308fd5a717daa628866527b689 Mon Sep 17 00:00:00 2001 From: Zhao Changmin Date: Wed, 21 Feb 2024 10:11:10 +0800 Subject: [PATCH] for rwkv4 (#10179) --- python/llm/src/bigdl/llm/transformers/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index ec75aee7..89dcee8d 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -582,10 +582,6 @@ class _BaseAutoModelClass: else: model = model_class(config, *model_args, **kwargs) - # rwkv model linear layers has been rescaled - if model.config.model_type == "rwkv": - model.layers_are_rescaled = True - # Loading args may differ based on their usage quant_device = "meta" if bigdl_lcmu_enabled else "cpu" model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device, @@ -647,6 +643,10 @@ class _BaseAutoModelClass: pass for param in model.parameters(): param.requires_grad_(False) + + # rwkv model linear layers has been rescaled + if model.config.model_type == "rwkv": + model.rwkv.layers_are_rescaled = True return model