diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 9372f667..6d81c677 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1032,10 +1032,18 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward(model, module.MLP, baichuan_mlp_forward) - replace_func(model, - module.BaichuanModel, - "get_alibi_mask", - baichuan_13b_get_alibi_mask) + if hasattr(model.model, 'get_alibi_mask_orig'): + # deepspeed rewrite "get_alibi_mask" to support baichuan + # https://github.com/microsoft/DeepSpeed/pull/4721 + replace_func(model, + module.BaichuanModel, + "get_alibi_mask_orig", + baichuan_13b_get_alibi_mask) + else: + replace_func(model, + module.BaichuanModel, + "get_alibi_mask", + baichuan_13b_get_alibi_mask) elif model.config.model_type == "baichuan": # baichuan1 if model.config.hidden_size == 4096: