optimize minicpm v 2.5 (#11793)

parent 356281cb80
commit 3d6cfa291d

2 changed files with 47 additions and 4 deletions
python/llm/src/ipex_llm/transformers/convert.py

@@ -749,7 +749,7 @@ def _optimize_pre(model, qtype=None):
         model.apply(merge_qkv)
     if model.config.model_type == "minicpmv":
         from ipex_llm.transformers.models.minicpmv import merge_qkv
-        model.apply(merge_qkv)
+        model.vpm.apply(merge_qkv)
         if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
             model.llm.config.model_type = "qwen2"
             _optimize_pre(model.llm, qtype=qtype)
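The change above stops applying merge_qkv to the whole MiniCPM-V model and applies it only to the vision tower (model.vpm); the language model (model.llm) is handled separately by the recursive _optimize_pre call. For context, here is a hedged sketch of what a merge_qkv_base-style helper does — this is not the actual implementation from ipex_llm.transformers.models.common, and the q_proj/k_proj/v_proj/qkv_proj attribute names are assumptions based on the usual Hugging Face attention layout:

import torch

def merge_qkv_sketch(module: torch.nn.Module, attn_class_name: str):
    # Only rewrite attention modules of the named class (e.g. "SiglipAttention").
    if module.__class__.__name__ != attn_class_name:
        return
    q, k, v = module.q_proj, module.k_proj, module.v_proj
    qkv = torch.nn.Linear(q.in_features,
                          q.out_features + k.out_features + v.out_features,
                          bias=q.bias is not None)
    # Concatenate the three projections so one GEMM computes q, k and v together.
    qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
    if q.bias is not None:
        qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
    module.qkv_proj = qkv
    del module.q_proj, module.k_proj, module.v_proj

Applied through model.vpm.apply(...), a helper like this visits every submodule of the vision tower and fuses only the matching attention layers.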
@@ -1742,9 +1742,13 @@ def _optimize_post(model, lightweight_bmm=False):
         if model.vpm.config.model_type == "siglip":
             # MiniCPM-V 2.6
             from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
-            convert_forward(model, vpm_module.SiglipAttention, siglip_attention_forward)
+            convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
         elif model.vpm.config.model_type == "idefics2":
             # MiniCPM-V 2.5
-            pass
+            from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+            from ipex_llm.transformers.models.minicpmv import minicpmv_chat_wrapper
+            convert_forward(model.vpm, vpm_module.Idefics2VisionAttention, siglip_attention_forward)
+            minicpmv_chat = minicpmv_chat_wrapper(module.MiniCPMV.chat)
+            model.chat = MethodType(minicpmv_chat, model)

     return model
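Both convert_forward calls now target model.vpm instead of the full model, and MiniCPMV.chat is rebound on the model instance through MethodType. A minimal sketch of this instance-level patching pattern follows — the helper below is an assumption for illustration, not ipex-llm's actual convert_forward:

from types import MethodType
import torch

def convert_forward_sketch(model: torch.nn.Module, target_class, new_forward):
    # Walk every submodule and rebind forward on instances of target_class,
    # so the optimized attention runs without modifying the class itself.
    for submodule in model.modules():
        if isinstance(submodule, target_class):
            submodule.forward = MethodType(new_forward, submodule)

Patching instances rather than the class keeps the change scoped to this one model object, which matters when the same library classes are used elsewhere in the process.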
python/llm/src/ipex_llm/transformers/models/minicpmv.py

@@ -18,11 +18,13 @@
 import torch
 from typing import Optional
 from ipex_llm.transformers.models.common import merge_qkv_base
+from transformers import AutoProcessor
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor


 def merge_qkv(module: torch.nn.Module):
-    return merge_qkv_base(module, "SiglipAttention")
+    merge_qkv_base(module, "SiglipAttention")
+    merge_qkv_base(module, "Idefics2VisionAttention")


 def siglip_attention_forward(
@@ -67,6 +69,43 @@ def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: t
     return scores


+def minicpmv_chat_wrapper(origin_chat):
+    def minicpmv_chat(
+        self,
+        image,
+        msgs,
+        tokenizer,
+        processor=None,
+        vision_hidden_states=None,
+        max_new_tokens=1024,
+        sampling=True,
+        max_inp_length=2048,
+        system_prompt='',
+        stream=False,
+        **kwargs
+    ):
+        if processor is None:
+            if getattr(self, "processor", None) is None:
+                self.processor = AutoProcessor.from_pretrained(self.config._name_or_path,
+                                                               trust_remote_code=True)
+            processor = self.processor
+        return origin_chat(
+            self=self,
+            image=image,
+            msgs=msgs,
+            tokenizer=tokenizer,
+            processor=processor,
+            vision_hidden_states=vision_hidden_states,
+            max_new_tokens=max_new_tokens,
+            sampling=sampling,
+            max_inp_length=max_inp_length,
+            system_prompt=system_prompt,
+            stream=stream,
+            **kwargs
+        )
+    return minicpmv_chat
+
+
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
         *inputs,
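The wrapper's only behavioral addition is the lazy processor lookup: when the caller passes processor=None, an AutoProcessor is loaded once from self.config._name_or_path and cached on the model before delegating to the original chat. A hedged usage sketch — model and tokenizer are assumed to come from ipex-llm's usual from_pretrained flow, and the image path and message format are illustrative:

from PIL import Image

image = Image.open("demo.jpg").convert("RGB")
msgs = [{"role": "user", "content": "What is in this picture?"}]

# Leaving processor unset hits the lazy AutoProcessor.from_pretrained(...)
# path added above on the first call; later calls reuse the cached
# self.processor.
response = model.chat(image=image, msgs=msgs, tokenizer=tokenizer,
                      sampling=True, max_new_tokens=128)
print(response)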