diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index a6a319eb..ba53e0ba 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1028,6 +1028,15 @@ def _optimize_pre(model, qtype=None):
         model.llm.config.model_type = "minicpm"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
+    elif model.config.model_type == "minicpmo":
+        # vpm opt
+        from ipex_llm.transformers.models.minicpmv import merge_qkv
+        model.vpm.apply(merge_qkv)
+
+        # llm opt
+        model.llm.config.model_type = "qwen2"
+        _optimize_pre(model.llm, qtype=qtype)
+        model.llm.config.model_type = "minicpmo"
     elif model.config.model_type == "megrezo":
         from ipex_llm.transformers.models.minicpmv import merge_qkv
         model.vision.apply(merge_qkv)
@@ -1944,6 +1953,18 @@ def _optimize_post(model):
         convert_forward(model.vpm, vpm_module.Idefics2VisionAttention, siglip_attention_forward)
         minicpmv_chat = minicpmv_chat_wrapper(module.MiniCPMV.chat)
         model.chat = MethodType(minicpmv_chat, model)
+    elif model.config.model_type == "minicpmo":
+        # vpm opt
+        vpm_modeling_module_name = model.vpm.__class__.__module__
+        vpm_module = importlib.import_module(vpm_modeling_module_name)
+
+        from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+        convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
+
+        # llm opt
+        model.llm.config.model_type = "qwen2"
+        _optimize_post(model.llm)
+        model.llm.config.model_type = "minicpmo"
     elif model.config.model_type == "megrezo":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
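
For reference, a minimal usage sketch of how these new `minicpmo` branches would be exercised. The model id `openbmb/MiniCPM-o-2_6` and the load arguments shown are illustrative assumptions, not part of this patch:

```python
# Sketch (assumed usage, not part of this patch): loading MiniCPM-o through
# ipex-llm runs _optimize_pre/_optimize_post. These merge the vision tower's
# QKV projections, patch SiglipAttention, and temporarily relabel the language
# model's config as "qwen2" so the existing Qwen2 pre/post optimizations are
# reused, before restoring model_type to "minicpmo".
from ipex_llm.transformers import AutoModel

model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-o-2_6",   # assumed model id, for illustration only
    trust_remote_code=True,
    load_in_low_bit="sym_int4",
    optimize_model=True,
)
```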