diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index c3cdde4a..1f95ccbe 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -561,6 +561,10 @@ class _BaseAutoModelClass: else: model = model_class(config, *model_args, **kwargs) + # rwkv model linear layers has been rescaled + if model.config.model_type == "rwkv": + model.layers_are_rescaled = True + # Loading args may differ based on their usage quant_device = "meta" if bigdl_lcmu_enabled else "cpu" model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,