fix minicpm-v-2 fp16 (#11819)

Yishuo Wang 2024-08-15 18:34:40 +08:00 committed by GitHub
parent 6543321f04
commit 750d4ad5dc
2 changed files with 25 additions and 0 deletions

@@ -1735,7 +1735,9 @@ def _optimize_post(model, lightweight_bmm=False):
    vpm_module = importlib.import_module(vpm_modeling_module_name)
    if not hasattr(model.vpm, "config"):
        # MiniCPM-V 2
        from ipex_llm.transformers.models.minicpmv import vision_transformer_attention_forward
        from ipex_llm.transformers.models.minicpmv import minicpmv_get_vision_embedding
        convert_forward(model.vpm, vpm_module.Attention, vision_transformer_attention_forward)
        model.get_vision_embedding = MethodType(minicpmv_get_vision_embedding, model)
    elif model.vpm.config.model_type == "siglip":
        # MiniCPM-V 2.6
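
For context, the lines that wire in vision_transformer_attention_forward rely on the usual PyTorch monkey-patching pattern: replace the forward of every matching submodule inside model.vpm with the new implementation, using the same per-instance rebinding that MethodType(minicpmv_get_vision_embedding, model) performs for the embedding method. The sketch below is illustrative only; patch_forward is a stand-in, not ipex-llm's actual convert_forward.

import types

import torch.nn as nn


def patch_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    # Rebind new_forward as the forward method of every submodule that is an
    # instance of target_cls, so module(x) dispatches to the replacement.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)

Because the binding happens per instance, the original class is left untouched and only the modules inside the converted model pick up the new behavior.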

@@ -61,6 +61,29 @@ def siglip_attention_forward(
    return attn_output, attn_weights


# MiniCPM-V-2
# modified from timm.models.vision_transformer.Attention.forward
def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
    bsz, q_len, hidden_size = x.size()

    # fused qkv projection, reshaped to (bsz, 3 * num_heads, q_len, head_dim)
    # and split into per-head query/key/value states
    qkv = self.qkv(x)
    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
    qkv = qkv.transpose(1, 2)
    query_states, key_states, value_states = qkv.chunk(3, dim=1)

    # scaled dot-product scores, normalized through ipex-llm's attention_softmax
    attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
    attn_weights = attention_softmax(attn_weights, self.training)
    attn_weights = self.attn_drop(attn_weights)
    attn_output = torch.matmul(attn_weights, value_states)

    # merge heads back to (bsz, q_len, hidden_size) and apply the output projection
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, hidden_size)
    attn_output = self.proj(attn_output)
    attn_output = self.proj_drop(attn_output)
    return attn_output


# MiniCPM-V-2_5
def minicpmv_chat_wrapper(origin_chat):
    def minicpmv_chat(
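
On the fp16 angle: the rewritten forward normalizes the attention scores through ipex-llm's attention_softmax helper rather than a plain in-dtype softmax. A common way such a helper keeps half-precision inference numerically stable is to upcast the scores to float32 for the softmax and cast back afterwards; the sketch below shows that pattern as an assumption about the intent, not the library's actual implementation (attention_softmax_sketch is a hypothetical name).

import torch


def attention_softmax_sketch(attn_weights: torch.Tensor, training: bool) -> torch.Tensor:
    # Assumed behaviour only: run the softmax in float32 so fp16/bf16 scores do
    # not lose precision during normalization, then cast back to the input dtype.
    # `training` mirrors the call site above; this sketch does not use it.
    if attn_weights.dtype in (torch.float16, torch.bfloat16):
        return torch.softmax(attn_weights.float(), dim=-1).to(attn_weights.dtype)
    return torch.softmax(attn_weights, dim=-1)

Upcasting only the normalization step is cheap relative to the surrounding matmuls, which is why it is a popular guard for half-precision attention.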