diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 05b21e0a..16abd567 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -739,6 +739,8 @@ def _optimize_pre(model, qtype=None):
     if model.config.model_type == "internlmxcomposer2":
         from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
         model.apply(pre_process_attn_and_mlp)
+    if model.config.model_type == "internvl_chat":
+        _optimize_pre(model.language_model, qtype=qtype)
     if model.config.model_type == "gemma2":
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
@@ -1268,6 +1270,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.internvl import _get_pos_embed
         vision_embedding = model.vision_model.embeddings
         vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding)
+        _optimize_post(model.language_model, lightweight_bmm=lightweight_bmm)
     elif model.config.model_type == "qwen":
         if hasattr(model.config, "visual"):
             # for Qwen-VL-Chat