update baichuan-7b

Huang, Xinshengzi 2024-08-22 18:16:33 +08:00
parent eb1e65f8a9
commit 4cf03d6212


@@ -1299,7 +1299,12 @@ def _optimize_post(model, lightweight_bmm=False):
                 setattr(model.model.layers[i].self_attn, "layer_idx", i)
             convert_forward(model, module.Attention, baichuan_attention_forward_7b)
             convert_forward(model, module.RMSNorm, llama_rms_norm_forward)
-            convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
+            if model.config.vocab_size == 125696:
+                # baichuan2-7B
+                convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
+            elif model.config.vocab_size == 64000:
+                # baichuan-7B
+                convert_forward(model, module.Model, baichuan_model_7b_forward)
         elif model.config.hidden_size == 5120:
             # baichuan-13B and baichuan2-13B
             from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
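
The change branches on vocab_size because the two checkpoints name their top-level model class differently in their modeling files: baichuan-7B (vocab size 64000) exposes module.Model, while baichuan2-7B (vocab size 125696) exposes module.BaichuanModel, so a single convert_forward call on module.BaichuanModel silently skipped baichuan-7B. As a minimal sketch of what a convert_forward-style helper does, assuming it simply rebinds forward on every matching submodule (the helper name comes from the diff; the body below is an illustration, not ipex-llm's exact implementation):

```python
import torch.nn as nn

def convert_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    # Walk the module tree and rebind `forward` on every instance of
    # target_cls, so the optimized function runs in place of the original.
    for _, sub_module in model.named_children():
        if isinstance(sub_module, target_cls):
            # Functions are descriptors: __get__ produces a bound method
            # for this specific instance.
            sub_module.forward = new_forward.__get__(sub_module, sub_module.__class__)
        # Recurse so nested occurrences are patched as well.
        convert_forward(sub_module, target_cls, new_forward)

# Usage mirroring the dispatch in the diff (values taken from the diff itself):
#   if model.config.vocab_size == 125696:    # baichuan2-7B
#       convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
#   elif model.config.vocab_size == 64000:   # baichuan-7B
#       convert_forward(model, module.Model, baichuan_model_7b_forward)
```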