diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 6db25be0..9fa62956 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1249,7 +1249,8 @@ def _optimize_post(model, lightweight_bmm=False):
             from ipex_llm.transformers.models.qwen import qwen_mlp_forward
             from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             from ipex_llm.transformers.models.qwen import qwen_model_forward
-            if model.config.max_position_embeddings == 8192:
+            if model.config.max_position_embeddings == 8192 \
+                    and model.config.hidden_size == 4096:
                 convert_forward(model,
                                 module.QWenAttention,
                                 qwen_attention_forward_registered
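
Note: the tightened guard applies the registered attention forward only when both config checks pass, so Qwen variants that share max_position_embeddings == 8192 but have a different hidden_size no longer take this path. For context, below is a minimal sketch of what a helper like convert_forward could look like; the body is an assumption for illustration, not ipex_llm's actual implementation.

    # Minimal sketch (assumption, not ipex_llm's actual code): swap in an
    # optimized forward on every sub-module of a given class.
    import types

    def convert_forward(model, target_class, new_forward):
        # model is assumed to be a torch.nn.Module, so named_modules()
        # walks the full sub-module tree.
        for _, sub_module in model.named_modules():
            if isinstance(sub_module, target_class):
                # Rebind new_forward as a bound method on this sub-module.
                sub_module.forward = types.MethodType(new_forward, sub_module)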