diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 385dfae5..8996f8a3 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1997,6 +1997,11 @@ def _optimize_post(model, lightweight_bmm=False):
             resampler_module_name = model.resampler.__class__.__module__
             resampler_module = importlib.import_module(resampler_module_name)
             resampler_module._in_projection_packed = _in_projection_packed
+
+            # for minicpm-v-2_6 benchmarking purposes
+            from ipex_llm.transformers.models.minicpmv import minicpmv_decode_stream_wrapper
+            minicpmv_decode_stream = minicpmv_decode_stream_wrapper(module.MiniCPMV._decode_stream)
+            model._decode_stream = MethodType(minicpmv_decode_stream, model)
         elif model.vpm.config.model_type == "idefics2":
             # MiniCPM-V 2.5
             from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index 638b92c4..89aca6d0 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -13,15 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+# Some parts of this file are adapted from
+# https://huggingface.co/openbmb/MiniCPM-V-2_6/blob/main/modeling_minicpmv.py
+# which is licensed under Apache License 2.0:
+#
+# https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE
+#
 
 import math
 import torch
+from threading import Thread
 from typing import Optional, List
 from torch.nn.functional import linear
 from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.common import attention_softmax
-from transformers import AutoProcessor
+from transformers import AutoProcessor, TextIteratorStreamer
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
@@ -111,6 +118,38 @@ def _in_projection_packed(
         return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
 
 
+# for minicpm-v-2_6 benchmarking purposes
+def minicpmv_decode_stream_wrapper(origin_decode_stream):
+    def minicpmv_decode_stream(
+        self,
+        inputs_embeds,
+        tokenizer,
+        **kwargs
+    ):
+        streamer = kwargs.get('streamer', None)
+        if streamer is not None:
+            terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+            generation_kwargs = {
+                'inputs_embeds': inputs_embeds,
+                'pad_token_id': 0,
+                'eos_token_id': terminators,
+            }
+            generation_kwargs.update(kwargs)
+
+            thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
+            thread.start()
+
+            return streamer
+        else:
+            return origin_decode_stream(
+                self=self,
+                inputs_embeds=inputs_embeds,
+                tokenizer=tokenizer,
+                **kwargs
+            )
+    return minicpmv_decode_stream
+
+
 # MiniCPM-V-2
 # modified from timm.models.vision_transformer.Attention.forward
 def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -209,6 +248,12 @@ def minicpmv_generate_wrapper(origin_generate):
         **kwargs
     ):
         RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
+
+        # for minicpm-v-2_6 benchmarking purposes
+        stream = kwargs.get("stream", False)
+        if isinstance(stream, TextIteratorStreamer):
+            kwargs.update({'streamer': stream})
+
         return origin_generate(
             *inputs,
             **kwargs,
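
Usage sketch (not part of the patch): a minimal driver exercising the streaming path this diff enables. It assumes a MiniCPM-V-2_6 checkpoint loaded through ipex_llm's AutoModel and that the model's chat method forwards its stream argument into generate, as in the upstream modeling_minicpmv.py; the model path, image file, and prompt below are hypothetical.

# Hypothetical benchmarking driver, assuming MiniCPM-V-2_6 loaded via ipex_llm.
# Passing a TextIteratorStreamer as `stream` is detected by the patched generate
# wrapper, which forwards it as `streamer`; the patched _decode_stream then runs
# self.llm.generate on a background thread and tokens can be timed as they arrive.
from PIL import Image
from transformers import AutoTokenizer, TextIteratorStreamer
from ipex_llm.transformers import AutoModel

model_path = "openbmb/MiniCPM-V-2_6"  # hypothetical model path
model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
                                  load_in_4bit=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
image = Image.open("sample.jpg").convert("RGB")  # hypothetical input image
msgs = [{"role": "user", "content": "Describe this image."}]

# chat returns a generator over the streamer's output when stream is truthy,
# so first-token latency and per-token timing can be measured in this loop.
res = model.chat(image=image, msgs=msgs, tokenizer=tokenizer, stream=streamer)
for text in res:
    print(text, end="", flush=True)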