diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 23c006ed..ac6081e1 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1735,7 +1735,9 @@ def _optimize_post(model, lightweight_bmm=False):
         vpm_module = importlib.import_module(vpm_modeling_module_name)
         if not hasattr(model.vpm, "config"):
             # MiniCPM-V 2
+            from ipex_llm.transformers.models.minicpmv import vision_transformer_attention_forward
             from ipex_llm.transformers.models.minicpmv import minicpmv_get_vision_embedding
+            convert_forward(model.vpm, vpm_module.Attention, vision_transformer_attention_forward)
             model.get_vision_embedding = MethodType(minicpmv_get_vision_embedding, model)
         elif model.vpm.config.model_type == "siglip":
             # MiniCPM-V 2.6
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index f25801be..32fba2d4 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -61,6 +61,29 @@ def siglip_attention_forward(
     return attn_output, attn_weights
 
 
+# MiniCPM-V-2
+# modified from timm.models.vision_transformer.Attention.forward
+def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
+    bsz, q_len, hidden_size = x.size()
+
+    qkv = self.qkv(x)
+    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
+    qkv = qkv.transpose(1, 2)
+    query_states, key_states, value_states = qkv.chunk(3, dim=1)
+
+    attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
+    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = self.attn_drop(attn_weights)
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+
+    attn_output = self.proj(attn_output)
+    attn_output = self.proj_drop(attn_output)
+    return attn_output
+
+
 # MiniCPM-V-2_5
 def minicpmv_chat_wrapper(origin_chat):
     def minicpmv_chat(
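
As a sanity check on the rewritten attention above, the following standalone sketch patches a timm-style attention module and verifies that the fused qkv split (view to num_heads * 3 rows, then chunk along the head dimension) matches the reference reshape/permute/unbind path numerically. TimmStyleAttention and patched_forward are illustrative names, not part of the ipex-llm codebase; plain torch softmax stands in for ipex_llm's attention_softmax helper, and instance-level MethodType patching stands in for convert_forward, which rebinds the forward on the module class instead.

import torch
import torch.nn as nn
from types import MethodType


class TimmStyleAttention(nn.Module):
    # Minimal stand-in exposing only the attributes the patched forward
    # relies on (qkv, attn_drop, proj, proj_drop, num_heads, head_dim, scale),
    # mirroring timm.models.vision_transformer.Attention.
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(0.0)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reference path, following timm's original forward.
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = self.attn_drop(attn.softmax(dim=-1))
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))


def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Same math as vision_transformer_attention_forward in the diff, with
    # torch's own softmax standing in for ipex_llm's attention_softmax helper.
    bsz, q_len, hidden_size = x.size()
    qkv = self.qkv(x).view(bsz, q_len, self.num_heads * 3, self.head_dim)
    qkv = qkv.transpose(1, 2)
    # After the transpose, dim 1 holds [q-heads | k-heads | v-heads].
    q, k, v = qkv.chunk(3, dim=1)
    attn = torch.matmul(q * self.scale, k.transpose(2, 3)).softmax(dim=-1)
    attn = self.attn_drop(attn)
    out = torch.matmul(attn, v).transpose(1, 2).contiguous()
    out = out.reshape(bsz, q_len, hidden_size)
    return self.proj_drop(self.proj(out))


torch.manual_seed(0)
m = TimmStyleAttention(dim=64, num_heads=8).eval()
x = torch.randn(2, 10, 64)
with torch.no_grad():
    ref = m(x)
    # Instance-level patching for the check; ipex_llm's convert_forward
    # replaces the forward on the module class instead.
    m.forward = MethodType(patched_forward, m)
    out = m(x)
print(torch.allclose(ref, out, atol=1e-6))  # expected: True

The key point the check exercises is the qkv layout: the fused linear emits [q; k; v] concatenated along the feature dimension, so viewing it as (bsz, q_len, num_heads * 3, head_dim) and chunking after the transpose yields the same per-head q, k, v as timm's reshape to (B, N, 3, num_heads, head_dim) followed by permute and unbind.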