optimize minicpm v 2.5 (#11793)

This commit is contained in:
Yishuo Wang 2024-08-14 16:07:24 +08:00 committed by GitHub
parent 356281cb80
commit 3d6cfa291d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 47 additions and 4 deletions

View file

@ -749,7 +749,7 @@ def _optimize_pre(model, qtype=None):
model.apply(merge_qkv)
if model.config.model_type == "minicpmv":
from ipex_llm.transformers.models.minicpmv import merge_qkv
model.apply(merge_qkv)
model.vpm.apply(merge_qkv)
if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
model.llm.config.model_type = "qwen2"
_optimize_pre(model.llm, qtype=qtype)
@ -1742,9 +1742,13 @@ def _optimize_post(model, lightweight_bmm=False):
if model.vpm.config.model_type == "siglip":
# MiniCPM-V 2.6
from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
convert_forward(model, vpm_module.SiglipAttention, siglip_attention_forward)
convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
elif model.vpm.config.model_type == "idefics2":
# MiniCPM-V 2.5
pass
from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
from ipex_llm.transformers.models.minicpmv import minicpmv_chat_wrapper
convert_forward(model.vpm, vpm_module.Idefics2VisionAttention, siglip_attention_forward)
minicpmv_chat = minicpmv_chat_wrapper(module.MiniCPMV.chat)
model.chat = MethodType(minicpmv_chat, model)
return model

View file

@ -18,11 +18,13 @@
import torch
from typing import Optional
from ipex_llm.transformers.models.common import merge_qkv_base
from transformers import AutoProcessor
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
def merge_qkv(module: torch.nn.Module):
    """Fuse separate q/k/v projections for both supported vision attention types.

    Intended to be applied over a model via ``model.apply(merge_qkv)``;
    ``merge_qkv_base`` presumably matches on the module's class name and is a
    no-op for non-matching modules (confirm in
    ``ipex_llm.transformers.models.common.merge_qkv_base``), so issuing both
    calls on every module is safe.

    :param module: any submodule visited by ``Module.apply``.
    """
    # Bug fix: the old ``return merge_qkv_base(module, "SiglipAttention")``
    # preceded the Idefics2 call, making it unreachable — MiniCPM-V 2.5
    # (Idefics2 vision tower) never got its qkv merged. Run both merges.
    merge_qkv_base(module, "SiglipAttention")        # MiniCPM-V 2.6 vision tower
    merge_qkv_base(module, "Idefics2VisionAttention")  # MiniCPM-V 2.5 vision tower
def siglip_attention_forward(
@ -67,6 +69,43 @@ def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: t
return scores
def minicpmv_chat_wrapper(origin_chat):
    """Wrap MiniCPM-V's ``chat`` so a HuggingFace processor is resolved lazily.

    When the caller supplies no ``processor``, one is loaded once from the
    model's own checkpoint path, cached on the model instance, and reused on
    later calls. All other arguments are forwarded to ``origin_chat``
    unchanged.

    :param origin_chat: the original (unbound) ``MiniCPMV.chat`` callable.
    :return: a drop-in replacement with the same signature.
    """
    def minicpmv_chat(
        self,
        image,
        msgs,
        tokenizer,
        processor=None,
        vision_hidden_states=None,
        max_new_tokens=1024,
        sampling=True,
        max_inp_length=2048,
        system_prompt='',
        stream=False,
        **kwargs
    ):
        # Resolve the processor: explicit argument wins, then the per-model
        # cache, then a fresh load (stored back on the model for next time).
        if processor is None:
            cached = getattr(self, "processor", None)
            if cached is None:
                cached = AutoProcessor.from_pretrained(
                    self.config._name_or_path, trust_remote_code=True)
                self.processor = cached
            processor = cached
        # Delegate with every argument passed by keyword (``self`` included,
        # since ``origin_chat`` is an unbound function).
        return origin_chat(
            self=self,
            image=image,
            msgs=msgs,
            tokenizer=tokenizer,
            processor=processor,
            vision_hidden_states=vision_hidden_states,
            max_new_tokens=max_new_tokens,
            sampling=sampling,
            max_inp_length=max_inp_length,
            system_prompt=system_prompt,
            stream=stream,
            **kwargs
        )
    return minicpmv_chat
def minicpmv_generate_wrapper(origin_generate):
def generate(
*inputs,