diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 7edb2b7a..b8c68995 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1755,19 +1755,29 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.MiniCPMModel,
                         minicpm_model_forward)
     elif model.config.model_type == "minicpmv":
-        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
-            model.llm.config.model_type = "qwen2"
-            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
-            model.llm.config.model_type = "minicpmv"
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.minicpmv import minicpmv_generate_wrapper
         minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
         model.generate = MethodType(minicpmv_generate, model)
-        modeling_module_name = model.vpm.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
-        convert_forward(model, module.SiglipAttention, siglip_attention_forward)
+        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
+            # MiniCPM-V 2.6
+            model.llm.config.model_type = "qwen2"
+            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
+            model.llm.config.model_type = "minicpmv"
+        elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
+            # MiniCPM-V 2.5
+            pass
+
+        vpm_modeling_module_name = model.vpm.__class__.__module__
+        vpm_module = importlib.import_module(vpm_modeling_module_name)
+        if model.vpm.config.model_type == "siglip":
+            # MiniCPM-V 2.6
+            from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+            convert_forward(model, vpm_module.SiglipAttention, siglip_attention_forward)
+        elif model.vpm.config.model_type == "idefics2":
+            # MiniCPM-V 2.5
+            pass
 
     return model
 
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index 159e61b6..03d8d2f4 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -69,32 +69,12 @@ def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: t
 
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
-        self,
-        input_ids=None,
-        pixel_values=None,
-        tgt_sizes=None,
-        image_bound=None,
-        attention_mask=None,
-        tokenizer=None,
-        vision_hidden_states=None,
-        return_vision_hidden_states=False,
-        stream=False,
-        decode_text=False,
+        *inputs,
         **kwargs
     ):
         RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
         return origin_generate(
-            self=self,
-            input_ids=input_ids,
-            pixel_values=pixel_values,
-            tgt_sizes=tgt_sizes,
-            image_bound=image_bound,
-            attention_mask=attention_mask,
-            tokenizer=tokenizer,
-            vision_hidden_states=vision_hidden_states,
-            return_vision_hidden_states=return_vision_hidden_states,
-            stream=stream,
-            decode_text=decode_text,
-            **kwargs
+            *inputs,
+            **kwargs,
        )
    return generate
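
Note: below is a minimal, self-contained sketch (not the ipex_llm code itself; FakeMiniCPMV and its generate() signature are hypothetical) of the pass-through wrapper pattern the second hunk adopts. Forwarding *inputs and **kwargs keeps a single wrapper compatible with both MiniCPM-V 2.5 and 2.6, whose generate() signatures differ, instead of hard-coding the 2.6 parameter list; the RepetitionPenaltyLogitsProcessor patching is omitted here.

from types import MethodType

def generate_wrapper(origin_generate):
    # Forward whatever the caller supplies; no parameter list to keep in sync
    # with any particular MiniCPM-V version.
    def generate(*inputs, **kwargs):
        return origin_generate(*inputs, **kwargs)
    return generate

class FakeMiniCPMV:
    # Hypothetical stand-in; real MiniCPM-V 2.5 and 2.6 expose different
    # generate() signatures, which is why the wrapper stops enumerating them.
    def generate(self, input_ids=None, pixel_values=None, **kwargs):
        return {"input_ids": input_ids, "pixel_values": pixel_values, **kwargs}

model = FakeMiniCPMV()
# Same binding as in convert.py: the instance arrives as inputs[0] and is
# forwarded to the original unbound generate().
model.generate = MethodType(generate_wrapper(FakeMiniCPMV.generate), model)
print(model.generate(input_ids=[1, 2, 3], tgt_sizes=[(32, 32)]))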