diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index d43043d3..a7da4efc 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1849,6 +1849,11 @@ def _optimize_post(model, lightweight_bmm=False):
             # MiniCPM-V 2.6
             from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
             convert_forward(model.vpm, vpm_module.SiglipAttention, siglip_attention_forward)
+
+            from ipex_llm.transformers.models.minicpmv import _in_projection_packed
+            resampler_module_name = model.resampler.__class__.__module__
+            resampler_module = importlib.import_module(resampler_module_name)
+            resampler_module._in_projection_packed = _in_projection_packed
         elif model.vpm.config.model_type == "idefics2":
             # MiniCPM-V 2.5
             from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index 32fba2d4..638b92c4 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -17,7 +17,8 @@
 
 import math
 import torch
-from typing import Optional
+from typing import Optional, List
+from torch.nn.functional import linear
 from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.common import attention_softmax
 from transformers import AutoProcessor
@@ -61,6 +62,55 @@ def siglip_attention_forward(
     return attn_output, attn_weights
 
 
+# MiniCPM-V-2_6
+def _in_projection_packed(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    w: torch.Tensor,
+    b: Optional[torch.Tensor] = None,
+) -> List[torch.Tensor]:
+    E = q.size(-1)
+    if k is v:
+        if q is k:
+            # self-attention
+            proj = linear(q, w, b)
+            # reshape to 3, E and not E, 3 is deliberate for
+            # better memory coalescing and keeping same order as chunk()
+            proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+            proj = proj.contiguous()
+            return proj[0], proj[1], proj[2]
+        else:
+            # encoder-decoder attention
+            w_q, w_kv = w.split([E, E * 2])
+            if b is None:
+                b_q = b_kv = None
+            else:
+                b_q, b_kv = b.split([E, E * 2])
+            q_proj = linear(q, w_q, b_q)
+            kv_proj = linear(k, w_kv, b_kv)
+            # reshape to 2, E and not E, 2 is deliberate for
+            # better memory coalescing and keeping same order as chunk()
+            kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+            kv_proj = kv_proj.contiguous()
+            return (q_proj, kv_proj[0], kv_proj[1])
+    else:
+        w_q, w_k, w_v = w.chunk(3)
+        # ipex-llm changes start: add contiguous to workaround a ipex bug
+        q = q.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
+        w_q = w_q.contiguous()
+        w_k = w_k.contiguous()
+        w_v = w_v.contiguous()
+        # ipex-llm changes end
+        if b is None:
+            b_q = b_k = b_v = None
+        else:
+            b_q, b_k, b_v = b.chunk(3)
+        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
+
+
 # MiniCPM-V-2
 # modified from timm.models.vision_transformer.Attention.forward
 def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
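
Note (illustration, not part of the patch): the minicpmv.py hunk re-implements the packed q/k/v in-projection helper with extra .contiguous() calls, and the convert.py hunk swaps it into the resampler's module namespace so the MiniCPM-V 2.6 resampler picks it up. The sketch below, with purely illustrative tensor shapes and variable names, shows what the self-attention branch computes: one packed linear projection whose unflatten/transpose split matches a plain chunk() over the last dimension.

# Minimal sketch of the self-attention branch above; L, N, E and the random
# weights are illustrative assumptions, not values taken from MiniCPM-V.
import torch
from torch.nn.functional import linear

L, N, E = 4, 2, 8                      # sequence length, batch, embed dim
x = torch.randn(L, N, E)               # q is k is v in self-attention
w = torch.randn(3 * E, E)              # packed in-projection weight
b = torch.randn(3 * E)                 # packed in-projection bias

# Patched path: project once, reshape to a leading (3, ...) axis, then index.
proj = linear(x, w, b)
proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2)
q1, k1, v1 = proj[0].contiguous(), proj[1].contiguous(), proj[2].contiguous()

# Reference path: straightforward chunk() of the same projection.
q2, k2, v2 = linear(x, w, b).chunk(3, dim=-1)

assert torch.allclose(q1, q2) and torch.allclose(k1, k2) and torch.allclose(v1, v2)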