LLM: support baichuan2-13b using AutoTP (#10691)
parent c7422712fc
commit 44922bb5c2
1 changed file with 12 additions and 4 deletions
@@ -1032,10 +1032,18 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.MLP,
                             baichuan_mlp_forward)
-            replace_func(model,
-                         module.BaichuanModel,
-                         "get_alibi_mask",
-                         baichuan_13b_get_alibi_mask)
+            if hasattr(model.model, 'get_alibi_mask_orig'):
+                # deepspeed rewrites "get_alibi_mask" to support baichuan
+                # https://github.com/microsoft/DeepSpeed/pull/4721
+                replace_func(model,
+                             module.BaichuanModel,
+                             "get_alibi_mask_orig",
+                             baichuan_13b_get_alibi_mask)
+            else:
+                replace_func(model,
+                             module.BaichuanModel,
+                             "get_alibi_mask",
+                             baichuan_13b_get_alibi_mask)
     elif model.config.model_type == "baichuan":
         # baichuan1
         if model.config.hidden_size == 4096:
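Note on the helpers above: replace_func and convert_forward are repo-internal utilities, and the new branch dispatches on whether DeepSpeed AutoTP has already rewritten Baichuan's get_alibi_mask (keeping the original under get_alibi_mask_orig). A minimal sketch of that pattern, assuming replace_func behaves like a class-level monkey patch; the helper name _replace_method below is hypothetical and not the repo's actual implementation:

import torch.nn as nn

def _replace_method(model: nn.Module, target_class: type, method_name: str, new_method):
    # Hypothetical stand-in for replace_func: if the model contains an
    # instance of target_class, rebind method_name on the class so every
    # instance (including tensor-parallel shards) calls new_method.
    if any(isinstance(m, target_class) for m in model.modules()):
        setattr(target_class, method_name, new_method)

# Dispatch mirroring the diff (assumption based on the comment above): DeepSpeed
# rewrites get_alibi_mask for Baichuan (DeepSpeed PR #4721) and the original
# survives as get_alibi_mask_orig, so the optimized mask is installed under
# whichever name is present:
#   target = "get_alibi_mask_orig" if hasattr(model.model, "get_alibi_mask_orig") \
#       else "get_alibi_mask"
#   _replace_method(model, module.BaichuanModel, target, baichuan_13b_get_alibi_mask)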
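For the bigger picture, a rough sketch of the AutoTP flow this commit targets, assuming the usual two-step pattern of sharding with deepspeed.init_inference and then applying this repo's optimization pass; the model id, tp size, dtype, and low_bit value are placeholders, the optimize_model import path is an assumption, and the repo's AutoTP example may differ in details:

import torch
import deepspeed
from transformers import AutoModelForCausalLM
from ipex_llm import optimize_model  # assumed import path for this repo's optimizer

# Placeholder setup; real AutoTP runs launch one process per rank.
model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan2-13B-Chat",
                                             torch_dtype=torch.float16,
                                             trust_remote_code=True)

# DeepSpeed AutoTP shards the model across ranks; for Baichuan it also patches
# get_alibi_mask (keeping the original as get_alibi_mask_orig), which is why
# the diff above probes for that attribute before installing the optimized mask.
model = deepspeed.init_inference(model,
                                 tensor_parallel={"tp_size": 2},
                                 dtype=torch.float16,
                                 replace_with_kernel_inject=False)

# Apply this repo's low-bit optimization to the sharded module, which runs
# _optimize_post and reaches the code path changed above.
model = optimize_model(model.module, low_bit="sym_int4")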