fix rwkv v5 fp16 (#10474)
This commit is contained in:
parent 72bcc27da9
commit 749bedaf1e

1 changed file with 1 addition and 0 deletions
@@ -538,6 +538,7 @@ def _optimize_pre(model):
     # for rwkv models (verified RWKV/rwkv-4-world-7b)
     if model.config.model_type == "rwkv":
         model.rwkv._rescale_layers()
+        model.rwkv.layers_are_rescaled = True
     # process NormHead module in Baichuan2 7B and 13B
     if model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
         # NormHead do normalization on the weights just once at inference time.
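A minimal sketch of the pattern this one-line change relies on, assuming Hugging Face transformers' native RWKV support; the explicit .half() cast and the surrounding load/generate code are illustrative and not part of the patch, and the model id is the one named in the comment in the diff.

# Sketch: rescale RWKV block weights once while they are still full precision,
# record that rescaling is done, and only then cast to fp16 so the forward
# pass does not attempt to rescale the converted weights again.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "RWKV/rwkv-4-world-7b"  # model mentioned in the code comment above

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()

# Mirror the two lines from the hunk: rescale once, then mark it as done.
model.rwkv._rescale_layers()
model.rwkv.layers_are_rescaled = True

# Cast to half precision afterwards; the already-rescaled weights carry over.
model = model.half()

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, my name is", return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0]))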