LLM: support baichuan2-13b using AutoTP (#10691)

This commit is contained in:
binbin Deng 2024-04-09 14:06:01 +08:00 committed by GitHub
parent c7422712fc
commit 44922bb5c2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1032,6 +1032,14 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model, convert_forward(model,
module.MLP, module.MLP,
baichuan_mlp_forward) baichuan_mlp_forward)
if hasattr(model.model, 'get_alibi_mask_orig'):
# deepspeed rewrite "get_alibi_mask" to support baichuan
# https://github.com/microsoft/DeepSpeed/pull/4721
replace_func(model,
module.BaichuanModel,
"get_alibi_mask_orig",
baichuan_13b_get_alibi_mask)
else:
replace_func(model, replace_func(model,
module.BaichuanModel, module.BaichuanModel,
"get_alibi_mask", "get_alibi_mask",