LLM: support baichuan2-13b using AutoTP (#10691)

This commit is contained in:
binbin Deng 2024-04-09 14:06:01 +08:00 committed by GitHub
parent c7422712fc
commit 44922bb5c2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -1032,10 +1032,18 @@ def _optimize_post(model, lightweight_bmm=False):
 convert_forward(model,
 module.MLP,
 baichuan_mlp_forward)
-replace_func(model,
-module.BaichuanModel,
-"get_alibi_mask",
-baichuan_13b_get_alibi_mask)
+if hasattr(model.model, 'get_alibi_mask_orig'):
+# deepspeed rewrite "get_alibi_mask" to support baichuan
+# https://github.com/microsoft/DeepSpeed/pull/4721
+replace_func(model,
+module.BaichuanModel,
+"get_alibi_mask_orig",
+baichuan_13b_get_alibi_mask)
+else:
+replace_func(model,
+module.BaichuanModel,
+"get_alibi_mask",
+baichuan_13b_get_alibi_mask)
 elif model.config.model_type == "baichuan":
 # baichuan1
 if model.config.hidden_size == 4096: