From 1a6cbc473f9fe1863b3188da8df157892e9c2949 Mon Sep 17 00:00:00 2001
From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com>
Date: Thu, 7 Nov 2024 18:52:47 +0800
Subject: [PATCH] Add fused mlp optimizations to glm4 models (#12360)

* Add fused mlp to glm4 models

* Small fix
---
 .../llm/src/ipex_llm/transformers/convert.py  | 18 +++++++++++-------
 .../ipex_llm/transformers/models/chatglm4.py  |  2 +-
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 56231b20..57e034c9 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1055,13 +1055,15 @@ def _optimize_pre(model, qtype=None):
             # chatglm2 and chatglm3
             from ipex_llm.transformers.models.chatglm2 import split_mlp
             model.apply(split_mlp)
-        elif (
-            isinstance(model.config.eos_token_id, list)
-            and hasattr(model.transformer, "vision")
-            and model.config.num_layers != 40
-        ):
-            from ipex_llm.transformers.models.chatglm4v import merge_qkv
-            model.apply(merge_qkv)
+        elif isinstance(model.config.eos_token_id, list):
+            # glm4 family
+            if hasattr(model.transformer, "vision"):
+                if model.config.num_layers != 40:
+                    from ipex_llm.transformers.models.chatglm4v import merge_qkv
+                    model.apply(merge_qkv)
+            elif model.config.num_layers in [40, 28]:
+                from ipex_llm.transformers.models.chatglm2 import split_mlp
+                model.apply(split_mlp)
 
     return model
 
@@ -1463,9 +1465,11 @@ def _optimize_post(model, lightweight_bmm=False):
             from ipex_llm.transformers.models.chatglm4 import chatglm4_attention_forward
             from ipex_llm.transformers.models.chatglm4 import chatglm4_model_forward
             from ipex_llm.transformers.models.chatglm4 import chatglm4_encoder_forward
+            from ipex_llm.transformers.models.chatglm2 import mlp_forward
             convert_forward(model, module.SelfAttention, chatglm4_attention_forward)
             convert_forward(model, module.ChatGLMModel, chatglm4_model_forward)
             convert_forward(model, module.GLMTransformer, chatglm4_encoder_forward)
+            convert_forward(model, module.MLP, mlp_forward)
     elif "mpt" in model.config.model_type:
         if model.config.architectures is not None:
             modeling_module_name = model.__class__.__module__
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index 282ce5bf..e3ba6bdf 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -56,7 +56,7 @@ def chatglm4_model_forward(
     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
-        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
+        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
                                                 inputs)
         if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
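
Note on the change: the fused-MLP path reuses split_mlp and mlp_forward from
ipex_llm.transformers.models.chatglm2, which is presumably also why
chatglm4_model_forward now probes mlp.gate_proj instead of mlp.dense_h_to_4h
(after the split the packed projection is replaced). The sketch below is a
minimal, illustrative version of such a split, assuming the GLM MLP packs its
gate and up projections into one dense_h_to_4h linear; the gate_proj/up_proj
names follow the diff, but the function body is an assumption, not the
library's actual implementation.

import torch.nn as nn

def split_mlp_sketch(module: nn.Module):
    # Illustrative only: split a packed dense_h_to_4h projection into separate
    # gate_proj / up_proj linears so a SiLU-gated ("fused") MLP forward can
    # consume them. Attribute names mirror those used in the diff.
    if hasattr(module, "dense_h_to_4h") and hasattr(module, "dense_4h_to_h"):
        packed = module.dense_h_to_4h  # out_features == 2 * intermediate_size
        gate_w, up_w = packed.weight.data.chunk(2, dim=0)

        gate_proj = nn.Linear(packed.in_features, packed.out_features // 2,
                              bias=packed.bias is not None)
        up_proj = nn.Linear(packed.in_features, packed.out_features // 2,
                            bias=packed.bias is not None)
        gate_proj.weight.data = gate_w
        up_proj.weight.data = up_w
        if packed.bias is not None:
            gate_b, up_b = packed.bias.data.chunk(2, dim=0)
            gate_proj.bias.data = gate_b
            up_proj.bias.data = up_b

        module.gate_proj = gate_proj
        module.up_proj = up_proj
        # gate_proj is now the module that chatglm4_model_forward passes to
        # use_quantize_kv_cache; the packed projection is removed.
        del module.dense_h_to_4h

# model.apply(split_mlp) visits every submodule, so a hook like this only has
# to rewrite the MLP blocks it recognizes.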