Add fused RMS norm optimization for the ChatGLM model (#9256)
This commit is contained in:
parent
b15656229e
commit
bfc1e2d733
1 changed file with 3 additions and 0 deletions
|
|
@@ -233,6 +233,9 @@ def _optimize_post(model):
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.CoreAttention,
|
module.CoreAttention,
|
||||||
core_attn_forward_8eb45c)
|
core_attn_forward_8eb45c)
|
||||||
|
convert_forward(model,
|
||||||
|
module.RMSNorm,
|
||||||
|
llama_rms_norm_forward)
|
||||||
elif hasattr(model.config, 'vocab_size') and model.config.vocab_size == 130528:
|
elif hasattr(model.config, 'vocab_size') and model.config.vocab_size == 130528:
|
||||||
# chatglm-6b
|
# chatglm-6b
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue