diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 9fefd634..02e7f575 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1299,7 +1299,12 @@ def _optimize_post(model, lightweight_bmm=False):
                 setattr(model.model.layers[i].self_attn, "layer_idx", i)
             convert_forward(model, module.Attention, baichuan_attention_forward_7b)
             convert_forward(model, module.RMSNorm, llama_rms_norm_forward)
-            convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
+            if model.config.vocab_size == 125696:
+                # baichuan2-7B
+                convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
+            elif model.config.vocab_size == 64000:
+                # baichuan-7B
+                convert_forward(model, module.Model, baichuan_model_7b_forward)
         elif model.config.hidden_size == 5120:
             # baichuan-13B and baichuan2-13B
             from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
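
The new branch keys off config.vocab_size because the two 7B variants expose differently named model classes in their modeling modules: baichuan2-7B (vocab_size 125696) provides BaichuanModel, while the original baichuan-7B (vocab_size 64000) provides Model. A minimal usage sketch under these assumptions (the model id is illustrative; optimization runs inside from_pretrained):

# Sketch only, not part of the patch: loading the original baichuan-7B with ipex_llm,
# which now matches module.Model for vocab_size == 64000 instead of only
# handling module.BaichuanModel (baichuan2-7B).
from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan-7B",   # assumed model id for baichuan-7B
    load_in_4bit=True,
    trust_remote_code=True,
)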