Support MiniCPM-V-2_6 multi-modal benchmarking with latency text streamer (#11963)

* Support MiniCPM-V-2_6 multi-modal benchmarking with latency text streamer

* Style fixes
Yuwen Hu 2024-08-29 19:22:09 +08:00 committed by GitHub
parent 2e49e1f8e9
commit a9e485eb1b
2 changed files with 51 additions and 1 deletion

@@ -1997,6 +1997,11 @@ def _optimize_post(model, lightweight_bmm=False):
             resampler_module_name = model.resampler.__class__.__module__
             resampler_module = importlib.import_module(resampler_module_name)
             resampler_module._in_projection_packed = _in_projection_packed
+
+            # for minicpm-v-2_6 benchmarking purposes
+            from ipex_llm.transformers.models.minicpmv import minicpmv_decode_stream_wrapper
+            minicpmv_decode_stream = minicpmv_decode_stream_wrapper(module.MiniCPMV._decode_stream)
+            model._decode_stream = MethodType(minicpmv_decode_stream, model)
         elif model.vpm.config.model_type == "idefics2":
             # MiniCPM-V 2.5
             from ipex_llm.transformers.models.minicpmv import siglip_attention_forward

@@ -13,15 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+# Some parts of this file are adapted from
+# https://huggingface.co/openbmb/MiniCPM-V-2_6/blob/main/modeling_minicpmv.py
+# which is licensed under Apache License 2.0:
+#
+# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
+#
 import math
 import torch
+from threading import Thread
 from typing import Optional, List
 from torch.nn.functional import linear
 from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.common import attention_softmax
-from transformers import AutoProcessor
+from transformers import AutoProcessor, TextIteratorStreamer
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
@@ -111,6 +118,38 @@ def _in_projection_packed(
     return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
 
 
+# for minicpm-v-2_6 benchmarking purposes
+def minicpmv_decode_stream_wrapper(origin_decode_stream):
+    def minicpv_decode_stream(
+        self,
+        inputs_embeds,
+        tokenizer,
+        **kwargs
+    ):
+        streamer = kwargs.get('streamer', None)
+        if streamer is not None:
+            terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+            generation_kwargs = {
+                'inputs_embeds': inputs_embeds,
+                'pad_token_id': 0,
+                'eos_token_id': terminators,
+            }
+            generation_kwargs.update(kwargs)
+            thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
+            thread.start()
+            return streamer
+        else:
+            return origin_decode_stream(
+                self=self,
+                inputs_embeds=inputs_embeds,
+                tokenizer=tokenizer,
+                **kwargs
+            )
+    return minicpv_decode_stream
+
+
 # MiniCPM-V-2
 # modified from timm.models.vision_transformer.Attention.forward
 def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
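Note: how a latency benchmark might drive the patched streaming path above — a hedged sketch only; the checkpoint id, loading flags, and chat arguments are assumptions based on the usual MiniCPM-V-2_6 API, not taken from this commit:

import time
from transformers import AutoTokenizer, TextIteratorStreamer
from ipex_llm.transformers import AutoModel

model_path = "openbmb/MiniCPM-V-2_6"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
                                  load_in_low_bit="sym_int4")  # assumed flags

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
msgs = [{"role": "user", "content": "Tell me about Intel GPUs."}]

start = time.perf_counter()
# Passing a TextIteratorStreamer as `stream` triggers the patched path:
# generate() promotes it to a `streamer` kwarg, and _decode_stream launches
# self.llm.generate on a background thread, returning the streamer at once.
model.chat(image=None, msgs=msgs, tokenizer=tokenizer, stream=streamer)

first_token_latency = None
for i, text in enumerate(streamer):
    if i == 0:
        first_token_latency = time.perf_counter() - start
total = time.perf_counter() - start
print(f"first token: {first_token_latency:.3f}s, total: {total:.3f}s")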
@@ -209,6 +248,12 @@ def minicpmv_generate_wrapper(origin_generate):
         **kwargs
     ):
         RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
+
+        # for minicpm-v-2_6 benchmarking purposes
+        stream = kwargs.get("stream", False)
+        if isinstance(stream, TextIteratorStreamer):
+            kwargs.update({'streamer': stream})
+
         return origin_generate(
             *inputs,
             **kwargs,
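Note: the isinstance check above keeps the public API intact — stream=True (or the default False) still reaches origin_generate unchanged, and only an explicitly passed TextIteratorStreamer opts into the thread-based path. A tiny sketch of that dispatch (route_stream_kwargs and the Mock stand-in tokenizer are illustrative; TextIteratorStreamer only stores the tokenizer at construction time):

from unittest.mock import Mock
from transformers import TextIteratorStreamer

def route_stream_kwargs(**kwargs):
    # Mirrors the hunk above: promote a TextIteratorStreamer to `streamer`.
    stream = kwargs.get("stream", False)
    if isinstance(stream, TextIteratorStreamer):
        kwargs.update({'streamer': stream})
    return kwargs

streamer = TextIteratorStreamer(Mock())  # stand-in tokenizer for illustration
assert route_stream_kwargs(stream=streamer)["streamer"] is streamer
assert "streamer" not in route_stream_kwargs(stream=True)  # bool path unchanged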