LLM: support baichuan2-13b using AutoTP (#10691)

This commit is contained in:
binbin Deng 2024-04-09 14:06:01 +08:00 committed by GitHub
parent c7422712fc
commit 44922bb5c2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1032,10 +1032,18 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model,
module.MLP,
baichuan_mlp_forward)
replace_func(model,
module.BaichuanModel,
"get_alibi_mask",
baichuan_13b_get_alibi_mask)
if hasattr(model.model, 'get_alibi_mask_orig'):
# deepspeed rewrite "get_alibi_mask" to support baichuan
# https://github.com/microsoft/DeepSpeed/pull/4721
replace_func(model,
module.BaichuanModel,
"get_alibi_mask_orig",
baichuan_13b_get_alibi_mask)
else:
replace_func(model,
module.BaichuanModel,
"get_alibi_mask",
baichuan_13b_get_alibi_mask)
elif model.config.model_type == "baichuan":
# baichuan1
if model.config.hidden_size == 4096: