Add fused mlp optimizations to glm4 models (#12360)

* Add fused mlp to glm4 models

* Small fix

Yuwen Hu, 2024-11-07 18:52:47 +08:00 (committed by GitHub)
parent 520af4e9b5
commit 1a6cbc473f
2 changed files with 12 additions and 8 deletions
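
In short: GLM4 text models now reuse the MLP split and fused SwiGLU forward that chatglm2/chatglm3 already use. Below is a minimal usage sketch, assuming an Intel GPU and the illustrative checkpoint THUDM/glm-4-9b-chat; loading a model through ipex-llm runs _optimize_pre and _optimize_post, so the patched branches in this commit apply automatically:

from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

# Low-bit loading runs ipex-llm's model conversion, which calls
# _optimize_pre (MLP split) and _optimize_post (fused forward) below.
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/glm-4-9b-chat",   # illustrative GLM4 checkpoint
    load_in_4bit=True,
    trust_remote_code=True,
).to("xpu")                  # Intel GPU device targeted by ipex-llm

tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat",
                                          trust_remote_code=True)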

python/llm/src/ipex_llm/transformers/convert.py

@@ -1055,13 +1055,15 @@ def _optimize_pre(model, qtype=None):
             # chatglm2 and chatglm3
             from ipex_llm.transformers.models.chatglm2 import split_mlp
             model.apply(split_mlp)
-        elif (
-            isinstance(model.config.eos_token_id, list)
-            and hasattr(model.transformer, "vision")
-            and model.config.num_layers != 40
-        ):
-            from ipex_llm.transformers.models.chatglm4v import merge_qkv
-            model.apply(merge_qkv)
+        elif isinstance(model.config.eos_token_id, list):
+            # glm4 family
+            if hasattr(model.transformer, "vision"):
+                if model.config.num_layers != 40:
+                    from ipex_llm.transformers.models.chatglm4v import merge_qkv
+                    model.apply(merge_qkv)
+            elif model.config.num_layers in [40, 28]:
+                from ipex_llm.transformers.models.chatglm2 import split_mlp
+                model.apply(split_mlp)
     return model
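
The reused split_mlp pass is what enables the shared fused forward: GLM checkpoints keep the gate and up projections stacked in one dense_h_to_4h linear, and splitting them out gives the MLP the gate_proj/up_proj layout the common chatglm2 code expects. A rough sketch of the idea, with assumed shapes and attribute names (the real implementation lives in ipex_llm.transformers.models.chatglm2):

import torch
from torch import nn

def split_mlp_sketch(module: nn.Module):
    # dense_h_to_4h packs [gate; up] as one [2 * intermediate, hidden]
    # weight; split it into two Linear layers so the common SwiGLU
    # forward (and a fused low-bit kernel) can address them separately.
    if hasattr(module, "dense_h_to_4h") and hasattr(module, "dense_4h_to_h"):
        fused = module.dense_h_to_4h
        intermediate = fused.out_features // 2
        has_bias = fused.bias is not None
        gate_proj = nn.Linear(fused.in_features, intermediate, bias=has_bias)
        up_proj = nn.Linear(fused.in_features, intermediate, bias=has_bias)
        with torch.no_grad():
            gate_proj.weight.copy_(fused.weight[:intermediate])
            up_proj.weight.copy_(fused.weight[intermediate:])
            if has_bias:
                gate_proj.bias.copy_(fused.bias[:intermediate])
                up_proj.bias.copy_(fused.bias[intermediate:])
        module.gate_proj = gate_proj
        module.up_proj = up_proj
        del module.dense_h_to_4h

# model.apply(split_mlp) walks every submodule, so each transformer
# layer's MLP is rewritten in place.
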
@@ -1463,9 +1465,11 @@ def _optimize_post(model, lightweight_bmm=False):
                 from ipex_llm.transformers.models.chatglm4 import chatglm4_attention_forward
                 from ipex_llm.transformers.models.chatglm4 import chatglm4_model_forward
                 from ipex_llm.transformers.models.chatglm4 import chatglm4_encoder_forward
+                from ipex_llm.transformers.models.chatglm2 import mlp_forward
                 convert_forward(model, module.SelfAttention, chatglm4_attention_forward)
                 convert_forward(model, module.ChatGLMModel, chatglm4_model_forward)
                 convert_forward(model, module.GLMTransformer, chatglm4_encoder_forward)
+                convert_forward(model, module.MLP, mlp_forward)
             elif "mpt" in model.config.model_type:
                 if model.config.architectures is not None:
                     modeling_module_name = model.__class__.__module__
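
With the projections split, the GLM4 MLP can share chatglm2's mlp_forward, which is where the fusion happens. Numerically it is plain SwiGLU; the sketch below shows the eager-mode math under the split layout (attribute names as above; the real mlp_forward additionally dispatches to a single fused low-bit kernel on XPU when the input qualifies):

import torch
import torch.nn.functional as F

def mlp_forward_sketch(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # SwiGLU: SiLU(gate) * up, then the down projection (dense_4h_to_h).
    gate = self.gate_proj(hidden_states)
    up = self.up_proj(hidden_states)
    return self.dense_4h_to_h(F.silu(gate) * up)
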

python/llm/src/ipex_llm/transformers/models/chatglm4.py

@@ -56,7 +56,7 @@ def chatglm4_model_forward(
     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
-        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
+        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
                                                 inputs)
         if use_compress_kv and not isinstance(past_key_values,
                                               DynamicCompressCache):
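
The chatglm4.py change is the follow-on fix the second commit bullet refers to: use_quantize_kv_cache probes one representative linear layer, and once split_mlp has run, mlp.dense_h_to_4h no longer exists, so the probe moves to mlp.gate_proj. A hypothetical shape of that check, purely to show the role gate_proj plays here (the real use_quantize_kv_cache is imported from ipex-llm's model utils and its heuristics may differ):

import os
import torch
from torch import nn

def use_quantize_kv_cache_sketch(linear: nn.Module, x: torch.Tensor) -> bool:
    # Hypothetical approximation: respect an explicit override, otherwise
    # enable the quantized KV cache only when the probed layer was
    # converted to a low-bit type (attribute name assumed).
    override = os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE")
    if override is not None:
        return override == "1"
    return getattr(linear, "qtype", None) is not None and x.size(0) == 1

Since split_mlp guarantees gate_proj on every GLM4 text layer, probing layers[0].mlp.gate_proj stays valid after the rewrite.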