From 872a74481a321bfa77cc5addcc9df14c99fbc6fd Mon Sep 17 00:00:00 2001
From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com>
Date: Wed, 6 Nov 2024 19:16:58 +0800
Subject: [PATCH] Small optimization to glm4 models (#12351)

---
 python/llm/src/ipex_llm/transformers/convert.py         | 2 +-
 python/llm/src/ipex_llm/transformers/models/chatglm4.py | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index cd7d9101..56231b20 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1458,7 +1458,7 @@ def _optimize_post(model, lightweight_bmm=False):
             from ipex_llm.transformers.models.chatglm4v import vision_model_forward
             convert_forward(model, vision_module.VisionModel, vision_model_forward)

-        elif model.config.num_layers == 40:
+        elif model.config.num_layers in [40, 28]:
             # glm-4-9b
             from ipex_llm.transformers.models.chatglm4 import chatglm4_attention_forward
             from ipex_llm.transformers.models.chatglm4 import chatglm4_model_forward
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index 2daffedb..282ce5bf 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -44,6 +44,7 @@ def chatglm4_model_forward(
     use_cache: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else
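
Note: the sketch below is a simplified, self-contained illustration of the two ideas in this patch: dispatching the glm-4 optimization on `num_layers in [40, 28]`, and giving the replacement forward a `**kwargs` parameter so extra keyword arguments from callers do not raise `TypeError`. The `ToyGLMModel`, `optimized_forward`, and the instance-level `convert_forward` helper are hypothetical stand-ins, not the actual ipex_llm/transformers implementations.

```python
# Minimal sketch, assuming a toy model class; names here are illustrative only.
import types
from typing import Optional


class ToyGLMModel:
    """Stand-in for a ChatGLM4-style model with a configurable layer count."""

    def __init__(self, num_layers: int):
        self.num_layers = num_layers

    def forward(self, input_ids, use_cache: Optional[bool] = None):
        return f"original forward, {self.num_layers} layers"


def optimized_forward(self, input_ids, use_cache: Optional[bool] = None, **kwargs):
    # **kwargs swallows keyword arguments the optimized path does not use,
    # mirroring the **kwargs added to chatglm4_model_forward in this patch.
    return f"optimized forward, {self.num_layers} layers, extra={sorted(kwargs)}"


def convert_forward(model, new_forward):
    # Loosely mirrors the idea of ipex_llm's convert_forward: rebind the
    # forward method with the optimized implementation (here, per instance).
    model.forward = types.MethodType(new_forward, model)


model = ToyGLMModel(num_layers=28)
# Dispatch on layer count, as the patched check `num_layers in [40, 28]` does.
if model.num_layers in [40, 28]:
    convert_forward(model, optimized_forward)

# A keyword argument the original signature never declared is now ignored
# instead of causing a TypeError.
print(model.forward([1, 2, 3], use_cache=True, output_hidden_states=False))
```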