for rwkv4 (#10179)
This commit is contained in:
parent
de3dc609ee
commit
4fbf449c2d
1 changed files with 4 additions and 4 deletions
|
|
@ -582,10 +582,6 @@ class _BaseAutoModelClass:
|
||||||
else:
|
else:
|
||||||
model = model_class(config, *model_args, **kwargs)
|
model = model_class(config, *model_args, **kwargs)
|
||||||
|
|
||||||
# rwkv model linear layers has been rescaled
|
|
||||||
if model.config.model_type == "rwkv":
|
|
||||||
model.layers_are_rescaled = True
|
|
||||||
|
|
||||||
# Loading args may differ based on their usage
|
# Loading args may differ based on their usage
|
||||||
quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
|
quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
|
||||||
model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
|
model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
|
||||||
|
|
@ -647,6 +643,10 @@ class _BaseAutoModelClass:
|
||||||
pass
|
pass
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
param.requires_grad_(False)
|
param.requires_grad_(False)
|
||||||
|
|
||||||
|
# rwkv model linear layers has been rescaled
|
||||||
|
if model.config.model_type == "rwkv":
|
||||||
|
model.rwkv.layers_are_rescaled = True
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue