add fused rms optimization for chatglm model (#9256)
This commit is contained in:
parent
b15656229e
commit
bfc1e2d733
1 changed files with 3 additions and 0 deletions
|
|
@ -233,6 +233,9 @@ def _optimize_post(model):
|
|||
convert_forward(model,
|
||||
module.CoreAttention,
|
||||
core_attn_forward_8eb45c)
|
||||
convert_forward(model,
|
||||
module.RMSNorm,
|
||||
llama_rms_norm_forward)
|
||||
elif hasattr(model.config, 'vocab_size') and model.config.vocab_size == 130528:
|
||||
# chatglm-6b
|
||||
modeling_module_name = model.__class__.__module__
|
||||
|
|
|
|||
Loading…
Reference in a new issue