diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 57e034c9..fa164d87 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1056,13 +1056,14 @@ def _optimize_pre(model, qtype=None):
             from ipex_llm.transformers.models.chatglm2 import split_mlp
             model.apply(split_mlp)
         elif isinstance(model.config.eos_token_id, list):
+            from ipex_llm.transformers.models.chatglm2 import split_mlp
             # glm4 family
             if hasattr(model.transformer, "vision"):
                 if model.config.num_layers != 40:
                     from ipex_llm.transformers.models.chatglm4v import merge_qkv
                     model.apply(merge_qkv)
+                    model.apply(split_mlp)
             elif model.config.num_layers in [40, 28]:
-                from ipex_llm.transformers.models.chatglm2 import split_mlp
                 model.apply(split_mlp)
 
     return model
@@ -1459,6 +1460,8 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model, SiglipAttention, siglip_attention_forward)
             from ipex_llm.transformers.models.chatglm4v import vision_model_forward
             convert_forward(model, vision_module.VisionModel, vision_model_forward)
+            from ipex_llm.transformers.models.chatglm2 import mlp_forward
+            convert_forward(model, module.MLP, mlp_forward)
         elif model.config.num_layers in [40, 28]:
             # glm-4-9b