fix rwkv v5 fp16 (#10474)

This commit is contained in:
Yishuo Wang 2024-03-20 13:15:08 +08:00 committed by GitHub
parent 72bcc27da9
commit 749bedaf1e

View file

@ -538,6 +538,7 @@ def _optimize_pre(model):
# for rwkv models (verified RWKV/rwkv-4-world-7b)
if model.config.model_type == "rwkv":
model.rwkv._rescale_layers()
model.rwkv.layers_are_rescaled = True
# process NormHead module in Baichuan2 7B and 13B
if model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
# NormHead do normalization on the weights just once at inference time.