This commit is contained in:
Zhao Changmin 2024-02-21 10:11:10 +08:00 committed by GitHub
parent de3dc609ee
commit 4fbf449c2d

View file

@@ -582,10 +582,6 @@ class _BaseAutoModelClass:
else:
model = model_class(config, *model_args, **kwargs)
# rwkv model linear layers have been rescaled
if model.config.model_type == "rwkv":
model.layers_are_rescaled = True
# Loading args may differ based on their usage
quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
@@ -647,6 +643,10 @@ class _BaseAutoModelClass:
pass
for param in model.parameters():
param.requires_grad_(False)
# rwkv model linear layers have been rescaled
if model.config.model_type == "rwkv":
model.rwkv.layers_are_rescaled = True
return model