LLM: support baichuan2-13b using AutoTP (#10691)
This commit is contained in:
parent
c7422712fc
commit
44922bb5c2
1 changed files with 12 additions and 4 deletions
|
|
@ -1032,10 +1032,18 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.MLP,
|
module.MLP,
|
||||||
baichuan_mlp_forward)
|
baichuan_mlp_forward)
|
||||||
replace_func(model,
|
if hasattr(model.model, 'get_alibi_mask_orig'):
|
||||||
module.BaichuanModel,
|
# deepspeed rewrite "get_alibi_mask" to support baichuan
|
||||||
"get_alibi_mask",
|
# https://github.com/microsoft/DeepSpeed/pull/4721
|
||||||
baichuan_13b_get_alibi_mask)
|
replace_func(model,
|
||||||
|
module.BaichuanModel,
|
||||||
|
"get_alibi_mask_orig",
|
||||||
|
baichuan_13b_get_alibi_mask)
|
||||||
|
else:
|
||||||
|
replace_func(model,
|
||||||
|
module.BaichuanModel,
|
||||||
|
"get_alibi_mask",
|
||||||
|
baichuan_13b_get_alibi_mask)
|
||||||
elif model.config.model_type == "baichuan":
|
elif model.config.model_type == "baichuan":
|
||||||
# baichuan1
|
# baichuan1
|
||||||
if model.config.hidden_size == 4096:
|
if model.config.hidden_size == 4096:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue